diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 7ab4851321..0d2500fac2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -118,6 +118,52 @@ __global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output) } } +template +__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) { + // Input: BxSxMxNxH (Format 1) + // Output: MxBxNxSxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + + int n = threadIdx.y; // head_num_id + int s = blockIdx.x; // sequence_id + int b = blockIdx.y; // batch_id + int m = blockIdx.z; // matrix id (Q=0, K=1, V=2) + const int h = threadIdx.x; // head_element_id + + const int qk_head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + + const int qkv_head_sizes[3] = {qk_head_size, qk_head_size, v_head_size}; + + const int total_head_size = num_heads * (qkv_head_sizes[0] + qkv_head_sizes[1] + qkv_head_sizes[2]); + + int in_offset; + int out_offset; + int bias_offset; + in_offset = b * (total_head_size * sequence_length) + // B + s * (total_head_size) + // S + m * (qk_head_size * num_heads) + // M + n * qkv_head_sizes[m] + // N + h; // H + + out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M + b * (num_heads * qkv_head_sizes[m] * sequence_length) + // B + n * (sequence_length * qkv_head_sizes[m]) + // N + s * (qkv_head_sizes[m]) + // S + h; // H + + bias_offset = m * (num_heads * qk_head_size)+ // QKV + n * (qkv_head_sizes[m]) + // N + h; // H + + if (h < qkv_head_sizes[m]) { + output[out_offset] = input[in_offset] + biases[bias_offset]; + } +} + template __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) { int n = threadIdx.y; @@ -203,26 +249,30 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const template void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, - const int batch_size, const int sequence_length, const int num_heads, const int head_size, - const T* input, const T* biases, T* output) { + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const T* input, const T* biases, T* output, const int v_head_size) { const dim3 grid(sequence_length, batch_size, num_matrices); - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); + if (qk_head_size * num_heads <= max_threads_per_block) { + const dim3 block(qk_head_size, num_heads, 1); if (format == 2) { AddBiasTransposeTrt<<>>(input, biases, output); } else if (format == 1) { - AddBiasTransposeQKV<<>>(input, biases, output); + if ((v_head_size == -1) || (qk_head_size == v_head_size)) { + AddBiasTransposeQKV<<>>(input, biases, output); + } else { + AddBiasTransposeQKV<<>>(input, biases, output, v_head_size); + } } else { AddBiasTranspose<<>>(input, biases, output); } } else { const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); if (format == 2) { - AddBiasTransposeTrtLarge<<>>(head_size, input, biases, output); + AddBiasTransposeTrtLarge<<>>(qk_head_size, input, biases, output); } else if (format == 1) { - AddBiasTransposeQKVLarge<<>>(head_size, input, biases, output); + AddBiasTransposeQKVLarge<<>>(qk_head_size, input, biases, output); } else { - AddBiasTransposeLarge<<>>(head_size, input, biases, output); + AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } } } @@ -230,53 +280,55 @@ void InvokeAddBiasTranspose( template <> void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, - const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const half* input, const half* biases, half* output, - bool enable_half4) { - if (enable_half4 && 0 == (head_size % 4)) { - const int H = head_size / 4; + bool enable_half4, const int v_head_size) { + if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) { + const int H_q = qk_head_size / 4; + const int H_v = v_head_size / 4; const Half4* input2 = reinterpret_cast(input); const Half4* biases2 = reinterpret_cast(biases); Half4* output2 = reinterpret_cast(output); InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, - batch_size, sequence_length, num_heads, H, input2, biases2, output2); - } else if (0 == (head_size & 1)) { - const int H = head_size / 2; + batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v); + } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) { + const int H_q = qk_head_size / 2; + const int H_v = v_head_size / 2; const half2* input2 = reinterpret_cast(input); const half2* biases2 = reinterpret_cast(biases); half2* output2 = reinterpret_cast(output); InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, - batch_size, sequence_length, num_heads, H, input2, biases2, output2); + batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v); } else { InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, - batch_size, sequence_length, num_heads, head_size, input, biases, output); + batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size); } } template <> void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, - const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const float* input, const float* biases, float* output, - bool /*enable_half4*/) { - if (0 == (head_size % 4)) { - const int H = head_size / 4; + bool /*enable_half4*/, const int v_head_size) { + if (0 == (qk_head_size % 4)) { + const int H = qk_head_size / 4; const float4* input2 = reinterpret_cast(input); const float4* biases2 = reinterpret_cast(biases); float4* output2 = reinterpret_cast(output); InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, - batch_size, sequence_length, num_heads, H, input2, biases2, output2); - } else if (0 == (head_size & 1)) { - const int H = head_size / 2; + batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4); + } else if (0 == (qk_head_size & 1)) { + const int H = qk_head_size / 2; const float2* input2 = reinterpret_cast(input); const float2* biases2 = reinterpret_cast(biases); float2* output2 = reinterpret_cast(output); InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, - batch_size, sequence_length, num_heads, H, input2, biases2, output2); + batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2); } else { InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, - batch_size, sequence_length, num_heads, head_size, input, biases, output); + batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index 427b8003eb..0917d09b62 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -24,8 +24,8 @@ namespace cuda { template void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, - const int batch_size, const int sequence_length, const int num_heads, const int head_size, - const T* input, const T* biases, T* output, bool enable_half4); + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index a652ff41ba..85c801d8d9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -82,18 +82,31 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // bias shape (3 * hidden_size) const auto& bias_shape = bias->Shape(); - int hidden_size = static_cast(bias_shape[0]) / 3; + int q_hidden_size; + int k_hidden_size; + int v_hidden_size; - int head_size = hidden_size / num_heads_; + + if (qkv_hidden_sizes_.size() == 0) { + q_hidden_size = static_cast(bias_shape[0]) / 3; + k_hidden_size = static_cast(bias_shape[0]) / 3; + v_hidden_size = static_cast(bias_shape[0]) / 3; + } else { + q_hidden_size = static_cast(qkv_hidden_sizes_[0]); + k_hidden_size = static_cast(qkv_hidden_sizes_[1]); + v_hidden_size = static_cast(qkv_hidden_sizes_[2]); + } + + const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_}; TensorShapeVector output_shape(3); output_shape[0] = shape[0]; output_shape[1] = shape[1]; - output_shape[2] = static_cast(hidden_size); + output_shape[2] = static_cast(v_hidden_size); Tensor* output = context->Output(0, output_shape); int past_sequence_length = 0; - Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); + Tensor* present = GetPresent(context, past, batch_size, qkv_head_size[1], sequence_length, past_sequence_length); // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; @@ -103,12 +116,14 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == present && nullptr == extra_add_qk && !is_unidirectional_ && - HasFusedFp16Kernel(sm, head_size, sequence_length)); + qkv_head_size[0] == qkv_head_size[1] && + qkv_head_size[1] == qkv_head_size[2] && + HasFusedFp16Kernel(sm, qkv_head_size[0], sequence_length)); MHARunner* fused_runner = nullptr; if (use_fused_runner) { if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, head_size, sm)); + fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, qkv_head_size[0], sm)); } // In case some kernel not loaded due to shared memory limit, we need to double check here. if (fused_fp16_runner_->isValid(sequence_length)) { @@ -121,9 +136,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Use GEMM for fully connection. int m = batch_size * sequence_length; - int n = 3 * hidden_size; + int n = (q_hidden_size + k_hidden_size + v_hidden_size); int k = input_hidden_size; - size_t gemm_buffer_size = static_cast(batch_size) * sequence_length * 3 * hidden_size * element_size; + size_t gemm_buffer_size = static_cast(batch_size) * sequence_length * n * element_size; auto gemm_buffer = GetScratchBuffer(gemm_buffer_size); typedef typename ToCudaType::MappedType CudaT; @@ -140,10 +155,11 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, - head_size, + qkv_head_size[0], sequence_length, past_sequence_length, - fused_runner); + fused_runner, + qkv_head_size[2]); auto work_space = GetScratchBuffer(workSpaceSize); ORT_RETURN_IF_ERROR(LaunchAttentionKernel( @@ -154,7 +170,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { batch_size, sequence_length, num_heads_, - head_size, + qkv_head_size[0], past_sequence_length, is_unidirectional_, reinterpret_cast(gemm_buffer.get()), @@ -166,7 +182,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { work_space.get(), output->MutableData(), nullptr == present ? nullptr : present->MutableData(), - fused_runner)); + fused_runner, + qkv_head_size[2])); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 2cfc8b6d0d..91b722fc6b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -64,11 +64,13 @@ size_t GetAttentionWorkspaceSize( size_t element_size, size_t batch_size, size_t num_heads, - size_t head_size, + size_t qk_head_size, size_t sequence_length, size_t past_sequence_length, - void* fused_runner) { - size_t q_size = element_size * batch_size * sequence_length * num_heads * head_size; + void* fused_runner, + size_t v_head_size) { + size_t q_size = element_size * batch_size * sequence_length * num_heads * qk_head_size; + size_t v_size = element_size * batch_size * sequence_length * num_heads * v_head_size; if (fused_runner != nullptr) { // Offsets without padding is B + 1. When we add padding, the size need to increase to 2B + 1. @@ -76,7 +78,7 @@ size_t GetAttentionWorkspaceSize( return 4 * q_size + reinterpret_cast(fused_runner)->getWorkspaceSize() + sequenceOffsetBytes; } - return 3 * q_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, + return (2 * q_size + v_size) + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); } @@ -88,7 +90,7 @@ Status QkvToContext( const int batch_size, const int sequence_length, const int num_heads, - const int head_size, + const int qk_head_size, const size_t element_size, const T* input, const T* bias, @@ -102,34 +104,40 @@ Status QkvToContext( const T* extra_add_qk, T* present, bool use_persistent_softmax, - MHARunner* fused_runner) { + MHARunner* fused_runner, + const int v_head_size) { + const int max_threads_per_block = prop.maxThreadsPerBlock; // input should be BxSx3xNxH => qkv: 3xBxNxSxH T* qkv = workspace; if (bias == nullptr) { - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, input, qkv)); + } else { // For fused TRT attention, qkv need transpose to BxSxNx3xH const int format = (nullptr == fused_runner ? 1 : 2); const bool enable_half4 = true; LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, - sequence_length, num_heads, head_size, + sequence_length, num_heads, qk_head_size, input, bias, qkv, - enable_half4); + enable_half4, v_head_size); CUDA_RETURN_IF_ERROR(cudaGetLastError()); } // Q, K, V has size BxNxSxH const int batches = batch_size * num_heads; - const int size_per_batch = sequence_length * head_size; - const int total_size = batches * size_per_batch; + const int size_per_batch_qk = sequence_length * qk_head_size; + const int size_per_batch_v = sequence_length * v_head_size; + const int total_size_qk = batches * size_per_batch_qk; + + T* scratch1; + scratch1 = qkv + (batches * sequence_length * (qk_head_size + qk_head_size + v_head_size)); - T* scratch1 = qkv + 3 * total_size; T* temp_output = scratch1; if (nullptr != fused_runner && bias != nullptr) { - int* sequence_offset = reinterpret_cast(qkv + 4 * total_size); + int* sequence_offset = reinterpret_cast(qkv + 4 * total_size_qk); LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -147,44 +155,45 @@ Status QkvToContext( T* scratch2 = scratch1 + (bytes / element_size); const T* q = qkv; - const T* k = q + total_size; - const T* v = k + total_size; + const T* k = q + (batches * sequence_length * qk_head_size); + const T* v = k + (batches * sequence_length * qk_head_size); cublasSetStream(cublas, stream); // Concat past (2xBxNxS'xH) to present (2xBxNxS*xH): // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH) // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH) - const int present_size_per_batch = all_sequence_length * head_size; + const int present_size_per_batch_k = all_sequence_length * qk_head_size; + const int present_size_per_batch_v = all_sequence_length * v_head_size; if (nullptr != present) { ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, + LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, past, k, present)); // update pointers to present_k and present_v. k = present; - v = present + batches * present_size_per_batch; + v = present + batches * present_size_per_batch_k; } // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length. bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H_qk) and store in scratch1: BxNxSxS* // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const float rsqrt_head_size = 1.f / sqrt(static_cast(qk_head_size)); const int temp_matrix_size = sequence_length * all_sequence_length; float one = 1.0f; float zero = 0.f; - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + // For raw attention mask, the scalar 1/sqrt(H_qk) is moved to combine with softmax computation. float alpha = use_raw_attention_mask ? one : rsqrt_head_size; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - all_sequence_length, sequence_length, head_size, - &alpha, k, head_size, present_size_per_batch, - q, head_size, size_per_batch, - &zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop)); + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + all_sequence_length, sequence_length, qk_head_size, + &alpha, k, qk_head_size, present_size_per_batch_k, + q, qk_head_size, size_per_batch_qk, + &zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop)); // apply softmax and store result P to scratch2: BxNxSxS* if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask @@ -214,13 +223,13 @@ Status QkvToContext( temp_output = qkv; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, all_sequence_length, - &one, v, head_size, present_size_per_batch, + v_head_size, sequence_length, all_sequence_length, + &one, v, v_head_size, present_size_per_batch_v, scratch2, all_sequence_length, temp_matrix_size, - &zero, temp_output, head_size, size_per_batch, batches, prop)); + &zero, temp_output, v_head_size, size_per_batch_v, batches, prop)); // temp_output is BxNxSxH, transpose to output BxSxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + return LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, max_threads_per_block, false, temp_output, output); } @@ -232,7 +241,7 @@ Status LaunchAttentionKernel( int batch_size, int sequence_length, int num_heads, - int head_size, + const int qk_head_size, int past_sequence_length, bool is_unidirectional, const void* input, @@ -244,13 +253,14 @@ Status LaunchAttentionKernel( void* workspace, void* output, void* present, - void* fused_runner) { + void* fused_runner, + const int v_head_size) { // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax used in Torch. const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); if (element_size == 2) { - return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, qk_head_size, element_size, reinterpret_cast(input), reinterpret_cast(bias), reinterpret_cast(output), @@ -263,9 +273,10 @@ Status LaunchAttentionKernel( reinterpret_cast(extra_add_qk), reinterpret_cast(present), use_persistent_softmax, - reinterpret_cast(fused_runner)); + reinterpret_cast(fused_runner), + v_head_size); } else { - return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, qk_head_size, element_size, reinterpret_cast(input), reinterpret_cast(bias), reinterpret_cast(output), @@ -278,7 +289,8 @@ Status LaunchAttentionKernel( reinterpret_cast(extra_add_qk), reinterpret_cast(present), use_persistent_softmax, - nullptr); + nullptr, + v_head_size); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 954a4ba734..304686ba0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -21,10 +21,11 @@ size_t GetAttentionWorkspaceSize( size_t element_size, size_t batchsize, size_t num_heads, - size_t head_size, + size_t qk_head_size, size_t sequence_length, size_t past_sequence_length, - void* fused_runner); + void* fused_runner, + size_t v_head_size); Status LaunchAttentionKernel( const cudaDeviceProp& prop, // Device Properties @@ -34,7 +35,7 @@ Status LaunchAttentionKernel( int batch_size, // Batch size (B) int sequence_length, // Sequence length (S) int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) + const int qk_head_size, // Hidden layer size per head for q and k (H_qk) int past_sequence_length, // Sequence length in past state bool is_unidirectional, // Whether there is unidirecitonal mask. const void* input, // Input tensor @@ -46,8 +47,8 @@ Status LaunchAttentionKernel( void* workspace, // Temporary buffer void* output, // Output tensor void* present, // Present state output - void* fused_runner // Fused multi-head attention -); + void* fused_runner, // Fused multi-head attention + const int v_head_size); // Hidden layer size per head for v (H_v) Status LaunchDecoderAttentionKernel( const cudaDeviceProp& prop, // Device Properties diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index 5141fd117c..aec8d7fa3b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -868,6 +868,7 @@ Status LongformerQkvToContext( // The order of qkv space: // Q, K, V, Global_K, Global_V, Global_Q (format 0) // Q, K, V, Global_Q, Global_K, Global_V (format 1) + // Assume H_q == H_k == H_v if (format == 1 || max_num_global == 0 || nullptr == global_input) { if (bias == nullptr) { ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, @@ -876,7 +877,7 @@ Status LongformerQkvToContext( LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, sequence_length, num_heads, head_size, input, bias, qkv, - use_half4); + use_half4, head_size); } if (max_num_global > 0 && nullptr != global_input) { @@ -887,20 +888,20 @@ Status LongformerQkvToContext( LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, sequence_length, num_heads, head_size, global_input, global_bias, qkv + 3 * elements, - use_half4); + use_half4, head_size); } } } else { LaunchAddBiasTranspose(stream, 5, format, max_threads_per_block, batch_size, sequence_length, num_heads, head_size, input, bias, qkv, - use_half4); + use_half4, head_size); compact_global_q = (disable_compact_memory == false); LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, batch_size, compact_global_q ? max_num_global : sequence_length, num_heads, head_size, global_input + 2 * elements, global_bias, qkv + 5 * elements, - use_half4); + use_half4, head_size); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index aaa764cf2e..a89174fc35 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -119,7 +119,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const auto& bias_shape = bias->Shape(); const int hidden_size = SafeInt(bias_shape.GetDims()[0]) / 3; - const int head_size = hidden_size / num_heads_; + // Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in quantization + const int qkv_head_size[3] = {hidden_size / num_heads_, hidden_size / num_heads_, hidden_size / num_heads_}; TensorShapeVector output_shape(3); output_shape[0] = shape[0]; @@ -166,12 +167,12 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { n)); int past_sequence_length = 0; - Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, head_size, + Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, qkv_head_size[1], sequence_length, past_sequence_length); void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, - sequence_length, past_sequence_length, fused_runner); + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, qkv_head_size[0], + sequence_length, past_sequence_length, fused_runner, qkv_head_size[2]); auto work_space = GetScratchBuffer(workSpaceSize); return LaunchAttentionKernel( @@ -182,7 +183,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { batch_size, sequence_length, num_heads_, - head_size, + qkv_head_size[0], past_sequence_length, is_unidirectional_, reinterpret_cast(gemm_buffer.get()), @@ -194,7 +195,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { work_space.get(), output->MutableData(), nullptr == present_tensor ? nullptr : present_tensor->MutableData(), - fused_runner); + fused_runner, + qkv_head_size[2]); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index df4e2bc3bc..12e7d7a190 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -50,6 +50,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { int sequence_length = static_cast(shape[1]); int input_hidden_size = static_cast(shape[2]); + // Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP // bias shape (3 * hidden_size) const auto& bias_shape = bias->Shape(); int hidden_size = static_cast(bias_shape[0]) / 3; diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 27dc95f258..af68d6f01b 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -41,12 +41,13 @@ static void RunAttentionTest( bool only_enable_cuda = false, bool only_enable_cpu = false, std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}) { + const std::vector& extra_add_data = {}, + const bool disable_rocm = false) { input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning. int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu; - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu; + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda; int head_size = hidden_size / number_of_heads; @@ -188,17 +189,18 @@ static void RunAttentionTest( bool only_enable_cuda = false, bool only_enable_cpu = false, const std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}) { + const std::vector& extra_add_data = {}, + const bool disable_rocm = false) { RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length, - only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data); + only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data, disable_rocm); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length, - only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data); + only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data, disable_rocm); } TEST(AttentionTest, AttentionBatch1) { @@ -267,7 +269,11 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr1) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, - 0, false, true, qkv_sizes); + 0, true, false, qkv_sizes, {}, true); + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, + 0, false, true, qkv_sizes, {}, true); } TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { @@ -304,7 +310,11 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, - 0, false, true, qkv_sizes); + 0, true, false, qkv_sizes, {}, true); + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, + 0, false, true, qkv_sizes, {}, true); } TEST(AttentionTest, AttentionBatch1ExtraAdd) {