From 94c69f55d480cb4a8dcbc161d29ef3acca9392a7 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Mon, 22 Apr 2024 19:57:05 -0700 Subject: [PATCH] GQA 4 CPU (#20299) ### Description Support GQA operator on CPU with FP32. ### Motivation and Context Right now, models generated for CPU and GPU must be different. GQA CPU allows these models to be the same. --- docs/ContribOperators.md | 5 +- docs/OperatorKernels.md | 1 + .../contrib_ops/cpu/bert/attention_base.h | 1 + .../contrib_ops/cpu/bert/attention_helper.h | 70 + .../contrib_ops/cpu/bert/attention_utils.cc | 246 +++ .../contrib_ops/cpu/bert/attention_utils.h | 61 + .../contrib_ops/cpu/bert/gqa_attention_base.h | 276 +++ .../cpu/bert/group_query_attention.cc | 192 ++ .../cpu/bert/group_query_attention.h | 21 + .../cpu/bert/group_query_attention_helper.h | 299 +++ .../cpu/bert/multihead_attention.cc | 215 +- .../contrib_ops/cpu/bert/rotary_embedding.cc | 129 +- .../contrib_ops/cpu/bert/rotary_embedding.h | 6 + .../cpu/bert/rotary_embedding_helper.h | 25 +- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/bert_defs.cc | 5 +- ..._flash_attn.py => test_flash_attn_cuda.py} | 0 .../test/python/transformers/test_gqa_cpu.py | 1884 +++++++++++++++++ 18 files changed, 3159 insertions(+), 279 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/bert/attention_utils.cc create mode 100644 onnxruntime/contrib_ops/cpu/bert/attention_utils.h create mode 100644 onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h create mode 100644 onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc create mode 100644 onnxruntime/contrib_ops/cpu/bert/group_query_attention.h create mode 100644 onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h rename onnxruntime/test/python/transformers/{test_flash_attn.py => test_flash_attn_cuda.py} (100%) create mode 100644 onnxruntime/test/python/transformers/test_gqa_cpu.py diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3d984a54c0..8d3d807514 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2455,6 +2455,9 @@ This version of the operator has been available since version 1 of the 'com.micr Group Query Self/Cross Attention. Supports different number of heads for q and kv. Only supports causal or local attention. + Supports rotary position embedding. + Supports k-v cache. + CPU EP supports fp32... CUDA EP supports fp16. #### Version @@ -2514,7 +2517,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16), tensor(bfloat16)
+
T : tensor(float16), tensor(bfloat16), tensor(float)
Constrain input and output to float tensors.
M : tensor(int32)
Constrain mask to int tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f54f264255..8fa67ee172 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -481,6 +481,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index a6782daa58..af902a713e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -68,6 +68,7 @@ class AttentionBase { const Tensor* past_seq_len = nullptr) const; int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention bool is_unidirectional_; // whether every token can only attend to previous tokens. std::vector qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute. bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V. diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index f1a0ce994e..5d37213303 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -153,6 +153,49 @@ void PrepareMask(const int32_t* mask_index, } } +// Applies causal mask and past seqlens right pad to the mask_data buffer. +template +void PrepareMaskGQA(T* mask_data, + int batch_size, + int sequence_length, + int buffer_sequence_length, + int local_window_size, + const int32_t* seqlens_k) { + // mask_data has been filled with 0, and its shape is BxSxT + T* p_mask = mask_data; + // TODO: parallelize this + for (int b_i = 0; b_i < batch_size; b_i++) { + if (sequence_length > 1) { + // Apply causal/local mask for prompt case. + for (int s_i = 0; s_i < sequence_length; s_i++) { + for (int m_i = s_i + 1; m_i < buffer_sequence_length; m_i++) { + p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits::lowest(); + } + // Apply local mask. + if (local_window_size > 0) { + for (int m_i = 0; m_i < s_i - local_window_size; m_i++) { + p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits::lowest(); + } + } + } + } else if (sequence_length == 1) { + // Apply right padding to mask for token gen case. + int total_seqlen = seqlens_k[b_i] + 1; + for (int m_i = total_seqlen; m_i < buffer_sequence_length; m_i++) { + p_mask[m_i] = std::numeric_limits::lowest(); + } + // Apply local mask. + if (local_window_size > 0) { + for (int m_i = 0; m_i < total_seqlen - local_window_size - 1; m_i++) { + p_mask[m_i] = std::numeric_limits::lowest(); + } + } + } + ptrdiff_t mask_to_advance = SafeInt(sequence_length) * buffer_sequence_length; + p_mask += mask_to_advance; + } +} + // Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH // Returns a pointer to the start of present state chunk. template @@ -175,5 +218,32 @@ T* ConcatStateChunk(const T* past, return start; } +// GQA version of ConcatStateChunk +template +T* ConcatStateChunkGQA(const T* past, + const T* chunk, + T* present, + size_t present_buff_chunk_length, + size_t past_buff_chunk_length, + size_t past_chunk_length, + size_t new_chunk_length, + bool is_prompt, + bool past_present_share_buffer, + std::ptrdiff_t i) { + T* start = present + i * present_buff_chunk_length; + + T* p = start; + if (!is_prompt) { + if (!past_present_share_buffer) { + const T* src_past = past + i * past_buff_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); + } + p += past_chunk_length; + } + + memcpy(p, chunk, new_chunk_length * sizeof(T)); + return start; +} + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc new file mode 100644 index 0000000000..7b84971585 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc @@ -0,0 +1,246 @@ +#include "attention_utils.h" +#include "core/common/common.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/transpose_helper.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/cpu/math/element_wise_ops.h" + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace contrib { + +// Reshape Q/K/V from BxSxD to BxSxNxH +inline Status Reshape_BSD_to_BSNH(Tensor* qkv, + int batch_size, + int sequence_length, + int num_heads, + int head_size) { + qkv->Reshape(TensorShape({batch_size, sequence_length, num_heads, head_size})); + return Status::OK(); +} + +// Transpose Q/K/V from BxSxNxH to BxNxSxH +inline Status Transpose_BSNH_to_BNSH(const Tensor* qkv, + OrtValue& qkv_transposed, + concurrency::ThreadPool* tp) { + std::vector permutations({0, 2, 1, 3}); + gsl::span permutations_span{permutations}; + size_t from = 2, to = 1; + SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp); + return Status::OK(); +} + +// Add bias + transpose for each of Q/K/V +template +Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v + const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v) + OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v + int bias_offset, // bias offset to enter qkv_bias + int batch_size, // batch size + int sequence_length, // sequence_length for Q, kv_sequence_length for K/V + int num_heads, // num heads + int head_size, // head_size for Q/K, v_head_size for V + int hidden_size, // hidden_size for Q/K, v_hidden_size for V + OpKernelContext* context) { + // Note: the comments below will refer to Q's dimensions for simplicity + auto element_type = DataTypeImpl::GetType(); + constexpr size_t element_size = sizeof(T); + ProcessBroadcastSpanFuncs add_funcs{ + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + }, + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + }, + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + }}; // For element-wise add + + // Allocate space for output of Q(BS, D) + bias(D) + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + std::vector old_dims({batch_size, sequence_length, hidden_size}); + gsl::span old_dims_span{old_dims}; + TensorShape qkv_with_bias_shape(old_dims_span); + OrtValue qkv_with_bias; + Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias); + + // Get Q's bias from combined bias + std::vector bias_dims({hidden_size}); + gsl::span bias_dims_span{bias_dims}; + TensorShape bias_shape(bias_dims_span); + OrtValue bias; + Tensor::InitOrtValue(element_type, bias_shape, allocator, bias); + memcpy(bias.GetMutable()->MutableData(), qkv_bias + bias_offset, hidden_size * element_size); + + // Compute Q(BS, D) + bias(D) as broadcasted element-wise add + { + InputBroadcaster input_broadcaster(*bias.GetMutable(), *qkv); + const InputBroadcaster& const_input_broadcaster = input_broadcaster; + Tensor& output_tensor = *qkv_with_bias.GetMutable(); + + size_t span_size = input_broadcaster.GetSpanSize(); + size_t output_size = static_cast(output_tensor.Shape().Size()); + void* user_data = nullptr; + + const int loop_len = static_cast(output_size / span_size); + double unit_cost = 1.0f; + const auto cost = TensorOpCost{static_cast(input_broadcaster.Input0ElementSize()) * span_size, + static_cast(output_tensor.DataType()->Size()) * span_size, + unit_cost * span_size}; + auto tp = context->GetOperatorThreadPool(); + ThreadPool::TryParallelFor(tp, loop_len, cost, + [span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span, + std::ptrdiff_t last_span) { + InputBroadcaster segment_input_broadcaster(const_input_broadcaster); + segment_input_broadcaster.AdvanceBy(first_span * span_size); + + OutputBroadcaster segment_output_broadcaster(span_size, output_tensor, + first_span * span_size, last_span * span_size); + + BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data); + BroadcastLooper(segment_helper, add_funcs); + }); + } + + // Reshape Q from BxSxD to BxSxNxH + ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); + + // Transpose Q from BxSxNxH to BxNxSxH + auto tp = context->GetOperatorThreadPool(); + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp)); + + return Status::OK(); +} + +// Add bias + reshape for each of Q/K/V +// This is used in decoder_with_past when the sequence length is 1 +template +Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v + const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v) + OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v + int bias_offset, // bias offset to enter qkv_bias + int batch_size, // batch size + int sequence_length, // sequence_length for Q, kv_sequence_length for K/V + int num_heads, // num heads + int head_size, // head_size for Q/K, v_head_size for V + int hidden_size, // hidden_size for Q/K, v_hidden_size for V + OpKernelContext* context) { + // Note: the comments below will refer to Q's dimensions for simplicity + auto element_type = DataTypeImpl::GetType(); + constexpr size_t element_size = sizeof(T); + ProcessBroadcastSpanFuncs add_funcs{ + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + }, + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + }, + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + }}; // For element-wise add + + // Get Q's bias from combined bias + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + std::vector bias_dims({hidden_size}); + gsl::span bias_dims_span{bias_dims}; + TensorShape bias_shape(bias_dims_span); + OrtValue bias; + Tensor::InitOrtValue(element_type, bias_shape, allocator, bias); + auto num_bias_elements = SafeInt(hidden_size) * element_size; + memcpy(bias.GetMutable()->MutableData(), qkv_bias + bias_offset, num_bias_elements); + + // Compute Q(BS, D) + bias(D) as broadcasted element-wise add + { + InputBroadcaster input_broadcaster(*bias.GetMutable(), *qkv); + const InputBroadcaster& const_input_broadcaster = input_broadcaster; + Tensor& output_tensor = *qkv_with_bias.GetMutable(); + + size_t span_size = input_broadcaster.GetSpanSize(); + size_t output_size = static_cast(output_tensor.Shape().Size()); + void* user_data = nullptr; + + const int loop_len = static_cast(output_size / span_size); + double unit_cost = 1.0f; + const auto cost = TensorOpCost{static_cast(input_broadcaster.Input0ElementSize()) * span_size, + static_cast(output_tensor.DataType()->Size()) * span_size, + unit_cost * span_size}; + auto tp = context->GetOperatorThreadPool(); + ThreadPool::TryParallelFor(tp, loop_len, cost, + [span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span, + std::ptrdiff_t last_span) { + InputBroadcaster segment_input_broadcaster(const_input_broadcaster); + segment_input_broadcaster.AdvanceBy(first_span * span_size); + + OutputBroadcaster segment_output_broadcaster(span_size, output_tensor, + first_span * span_size, last_span * span_size); + + BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data); + BroadcastLooper(segment_helper, add_funcs); + }); + } + + // Reshape Q from BxSxD to BxNxSxH + qkv_with_bias.GetMutable()->Reshape(TensorShape({batch_size, num_heads, sequence_length, head_size})); + + return Status::OK(); +} + +template +Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) { + auto element_type = DataTypeImpl::GetType(); + std::vector new_dims({batch_size, num_heads, sequence_length, head_size}); + gsl::span new_dims_span{new_dims}; + TensorShape v_BNLH(new_dims_span); + Tensor::InitOrtValue(element_type, v_BNLH, allocator, out); + if (bias == nullptr) { + std::unique_ptr reshaped; + if (in->Shape().GetDims().size() == 3) { + reshaped = std::make_unique(in->DataType(), in->Shape(), const_cast(in->DataRaw()), in->Location()); + ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size)); + } + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out)); + } else { + const auto* qkv_bias = bias->Data(); + if (sequence_length == 1) { + ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context)); + } else { + ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context)); + } + } + return Status::OK(); +}; + +template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); + +template +Status MaybeTransposeToBNSH(AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, OrtValue& out) { + auto element_type = DataTypeImpl::GetType(); + std::vector new_dims({batch_size, num_heads, sequence_length, head_size}); + gsl::span new_dims_span{new_dims}; + TensorShape v_BNLH(new_dims_span); + Tensor::InitOrtValue(element_type, v_BNLH, allocator, out); + std::unique_ptr reshaped; + if (in->Shape().GetDims().size() == 3) { + reshaped = std::make_unique(in->DataType(), in->Shape(), const_cast(in->DataRaw()), in->Location()); + ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size)); + } + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out)); + + return Status::OK(); +}; + +template Status MaybeTransposeToBNSH(AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, OrtValue& out); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.h b/onnxruntime/contrib_ops/cpu/bert/attention_utils.h new file mode 100644 index 0000000000..d7fb0c496b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.h @@ -0,0 +1,61 @@ +#pragma once +#include "core/common/common.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/transpose_helper.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/cpu/math/element_wise_ops.h" + +namespace onnxruntime { +namespace contrib { + +// Reshape Q/K/V from BxSxD to BxSxNxH +Status Reshape_BSD_to_BSNH(Tensor* qkv, + int batch_size, + int sequence_length, + int num_heads, + int head_size); + +// Transpose Q/K/V from BxSxNxH to BxNxSxH +Status Transpose_BSNH_to_BNSH(const Tensor* qkv, + OrtValue& qkv_transposed, + concurrency::ThreadPool* tp = nullptr); + +// Add bias + transpose for each of Q/K/V +template +Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v + const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v) + OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v + int bias_offset, // bias offset to enter qkv_bias + int batch_size, // batch size + int sequence_length, // sequence_length for Q, kv_sequence_length for K/V + int num_heads, // num heads + int head_size, // head_size for Q/K, v_head_size for V + int hidden_size, // hidden_size for Q/K, v_hidden_size for V + OpKernelContext* context); // OpKernelContext + +// Add bias + reshape for each of Q/K/V +// This is used in decoder_with_past when the sequence length is 1 +template +Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v + const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v) + OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v + int bias_offset, // bias offset to enter qkv_bias + int batch_size, // batch size + int sequence_length, // sequence_length for Q, kv_sequence_length for K/V + int num_heads, // num heads + int head_size, // head_size for Q/K, v_head_size for V + int hidden_size, // hidden_size for Q/K, v_hidden_size for V + OpKernelContext* context); // OpKernelContext + +template +Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); + +template +Status MaybeTransposeToBNSH(AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, OrtValue& out); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h new file mode 100644 index 0000000000..b0ebc50215 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "attention_base.h" +#include "attention_helper.h" + +#include "core/common/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +class GQAAttentionBase : public AttentionBase { + protected: + GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) + : AttentionBase(info, require_same_hidden_size) {} + + int local_window_size_; + bool do_rotary_; + bool rotary_interleaved_; + + template + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxN_kvxSxH + const T* V, // V data with shape BxN_kvxSxH + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + const Tensor* seqlens_k, // past sequence lengths tensor + GroupQueryAttentionParameters& parameters, // attention parameters + AllocatorPtr allocator, // allocator for temporary tensors + OpKernelContext* context) const { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]); + + // Compute the attention score. + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + + void* mask_data = nullptr; + size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * seqlen_present_kv_cache * sizeof(T); + mask_data = allocator->Alloc(mask_data_bytes); + memset(mask_data, 0, mask_data_bytes); + BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); + + const T* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr; + const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + + bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; + + const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + ComputeAttentionProbs(static_cast(attention_probs), Q, k, + seqlens_k->Data(), static_cast(mask_data), + batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp); + + // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) + auto out_tmp_data = + allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T)); + BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator)); + + const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), static_cast(attention_probs), + v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, + past_present_share_buffer, packed_qkv, tp); + + return Status::OK(); + } + + private: + // Helper function to compute the attention probs. It does 2 things: + // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) + + // 1 x mask_data(B, N, S, T) + // attention_probs(B, N, S, T) = Softmax(attention_probs) + template + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // past sequence lengths tensor + T* mask_data, // buffer for mask data. + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention (S) + int past_buffer_sequence_length, // sequence length of past state + int present_buffer_sequence_length, // sequence length of present state + int head_size, // head size of self-attention + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + bool past_present_share_buffer, // whether present key and value share the same buffer + bool packed_qkv, // whether Q, K, V are packed + ThreadPool* tp) const { // thread pool + const bool is_prompt = sequence_length != 1; + const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0; + const int kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H + const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // L x H + const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H + const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + + PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k); + + const int loop_len = batch_size * num_heads_; + const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + // TODO: cost might differ for gqa because of right padding and total_sequence_length being sequence dependent + TensorOpCost unit_cost; + const size_t probs_matrix_bytes = SafeInt(sequence_length) * present_buffer_sequence_length * sizeof(T); + unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * present_buffer_sequence_length); + unit_cost.bytes_loaded = static_cast((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); + + unit_cost.bytes_loaded += static_cast(probs_matrix_bytes); + unit_cost.bytes_stored += static_cast(probs_matrix_bytes); + + if (present_key) { + double bytes_to_copy_key = static_cast(sizeof(T) * present_buff_chunk_length); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + } + + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i) / num_heads_; + const int head_index = static_cast(i) % num_heads_; + const int past_seqlen = sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; + const size_t past_chunk_length = static_cast(past_seqlen) * head_size; + + const int output_offset = static_cast(i) * sequence_length * present_buffer_sequence_length; + const int mask_offset = batch_index * sequence_length * present_buffer_sequence_length; + T* output = attention_probs + output_offset; + + // Broadcast mask data: (Bx)SxT -> (BxNx)SxT + // TODO: mask after present_sequence_length + memcpy(output, + mask_data + mask_offset, + probs_matrix_bytes); + + const T* k; + if (packed_qkv) { + k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + k = K + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_key) { + k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + } + + // TODO: CblasTrans stuff what do? + // Compute Q*K' + AttentionMask + // original transposed each iteration + // A: Q (B x N x) S x H (B x N x) S x H S x H + // B: K' (B x N x) T x H (B x N x) H x T H x T + // C: attention_probs (B x N x) S x T (B x N x) S x T S x T + const T* q; + if (packed_qkv) { + q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index; + } else { + q = Q + q_input_chunk_length * i; + } + math::Gemm(CblasNoTrans, CblasTrans, sequence_length, present_buffer_sequence_length, head_size, alpha, + q, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr); + } + }); + + // attention_probs(B, N, S, T) = Softmax(attention_probs) + const int N = batch_size * num_heads_ * sequence_length; + const int D = present_buffer_sequence_length; + ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp); + } + + template + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH + T* tmp_buffer, // buffer for temp use with size is BxNxSxH + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxN_kvxSxH + const int32_t* seqlens_k, // past sequence lengths tensor + int batch_size, // batch size + int sequence_length, // sequence length + int past_buffer_sequence_length, // sequence length in past state + int present_buffer_sequence_length, // sequence length in past state + int head_size, // head size of Q, K, V + int hidden_size, // hidden size of Output + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool past_present_share_buffer, // whether present key and value share the same buffer + bool packed_qkv, // whether Q, K, V are packed + ThreadPool* tp) const { + const bool is_prompt = sequence_length != 1; + const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0; + const int kv_num_heads_factor = num_heads_ / kv_num_heads_; + // TODO: what is with these being ptrdiff_t? + const ptrdiff_t q_input_chunk_length = SafeInt(sequence_length) * head_size; // S x H + const ptrdiff_t kv_input_chunk_length = SafeInt(sequence_length) * head_size; // L x H + const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H + const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + + // The cost of Gemm + TensorOpCost unit_cost; + unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * present_buffer_sequence_length); + unit_cost.bytes_loaded = static_cast((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * head_size * sizeof(T)); + + if (present_value) { + double bytes_to_copy_value = static_cast(present_buff_chunk_length * sizeof(T)); + unit_cost.bytes_loaded += bytes_to_copy_value; + unit_cost.bytes_stored += bytes_to_copy_value; + } + + const size_t bytes_to_copy_trans = SafeInt(head_size) * sizeof(T); + double bytes_to_copy_trans_all = static_cast(sequence_length * bytes_to_copy_trans); + unit_cost.bytes_loaded += bytes_to_copy_trans_all; + unit_cost.bytes_stored += bytes_to_copy_trans_all; + + ThreadPool::TryParallelFor(tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i / num_heads_); + const int head_index = static_cast(i % num_heads_); + const int past_seqlen = sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; + const size_t past_chunk_length = static_cast(past_seqlen) * head_size; + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_value) { + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + } + + T* current_tmp_data = reinterpret_cast(tmp_buffer) + q_input_chunk_length * i; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; + math::MatMul(sequence_length, head_size, present_buffer_sequence_length, + attention_probs + attention_probs_offset, + v, current_tmp_data, nullptr); + + // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v) + T* src = current_tmp_data; + ptrdiff_t dest_offset = (SafeInt(batch_index) * sequence_length * num_heads_ + head_index) * head_size; + T* dest = output + dest_offset; + for (int j = 0; j < sequence_length; j++) { + memcpy(dest, src, bytes_to_copy_trans); + src += head_size; + dest += hidden_size; + } + } + }); + } +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc new file mode 100644 index 0000000000..8e6b202ca4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "group_query_attention.h" +#include "group_query_attention_helper.h" +#include "attention_utils.h" +#include "rotary_embedding.h" +#include "rotary_embedding_helper.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/graph/onnx_protobuf.h" +#include "core/common/safeint.h" +#include "core/platform/threadpool.h" + +#include +#include + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +ONNX_OPERATOR_TYPED_KERNEL_EX( + GroupQueryAttention, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + GroupQueryAttention); + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; +} + +template +Status GroupQueryAttention::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_key = context->Input(3); + const Tensor* past_value = context->Input(4); + const Tensor* seqlens_k = context->Input(5); + const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); + + GroupQueryAttentionParameters parameters = {}; + constexpr float scale = 1.0f; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + scale)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int present_kv_seqlen = parameters.seqlen_present_kv_cache; + int head_size = parameters.head_size; + int q_hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + std::vector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(q_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_k_shape({static_cast(batch_size), static_cast(kv_num_heads_), static_cast(present_kv_seqlen), static_cast(head_size)}); + std::vector present_v_shape({static_cast(batch_size), static_cast(kv_num_heads_), static_cast(present_kv_seqlen), static_cast(head_size)}); + Tensor* present_k = context->Output(1, present_k_shape); + Tensor* present_v = context->Output(2, present_v_shape); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto element_type = DataTypeImpl::GetType(); + OrtValue Q; + OrtValue K; + OrtValue V; + if (packed_qkv) { + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); + } else if (sequence_length > 1) { + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); + } else { + Tensor::InitOrtValue(std::move(const_cast(*query)), Q); + Tensor::InitOrtValue(std::move(const_cast(*key)), K); + Tensor::InitOrtValue(std::move(const_cast(*value)), V); + } + + if (do_rotary_) { + rotary_embedding_helper::RotaryParameters rotary_params = {}; + rotary_params.batch_size = batch_size; + rotary_params.sequence_length = sequence_length; + rotary_params.hidden_size = q_hidden_size; + rotary_params.head_size = head_size; + rotary_params.rotary_embedding_dim = parameters.rotary_dim; + rotary_params.num_heads = num_heads_; + rotary_params.max_sequence_length = sequence_length; // unused + rotary_params.seq_stride = head_size; + rotary_params.head_stride = sequence_length * rotary_params.seq_stride; + rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride; + rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.transposed = true; + auto* tp = context->GetOperatorThreadPool(); + std::vector pos_ids(sequence_length == 1 ? batch_size : 1); + if (sequence_length == 1) { + for (int b = 0; b < batch_size; b++) { + pos_ids[b] = static_cast(seqlens_k->Data()[b]); + } + } else { + pos_ids[0] = static_cast(0); + } + const T* q_input; + const T* k_input; + T* q_rotary; + T* k_rotary; + if (packed_qkv) { + OrtValue RotaryQKV; + Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV); + q_input = Q.Get().Data(); + k_input = q_input + num_heads_ * sequence_length * head_size; + q_rotary = RotaryQKV.GetMutable()->MutableData(); + k_rotary = q_rotary + num_heads_ * sequence_length * head_size; + Q = RotaryQKV; + } else { + OrtValue RotaryQ; + Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ); + OrtValue RotaryK; + Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK); + q_input = Q.Get().Data(); + k_input = K.Get().Data(); + q_rotary = RotaryQ.GetMutable()->MutableData(); + k_rotary = RotaryK.GetMutable()->MutableData(); + Q = RotaryQ; + K = RotaryK; + } + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, + pos_ids.data(), cos_cache->Data(), + sin_cache->Data(), q_rotary, rotary_interleaved_)); + + rotary_params.num_heads = kv_num_heads_; + rotary_params.hidden_size = parameters.kv_hidden_size; + if (!packed_qkv) { + rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride; + } + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, + pos_ids.data(), cos_cache->Data(), + sin_cache->Data(), k_rotary, rotary_interleaved_)); + if (packed_qkv) { + const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; + T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; + ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV(tp, parameters, v_input, v_rotary)); + } + } + + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + // Compute the attention score and apply the score to V + return ApplyAttention(Q.Get().Data(), packed_qkv ? nullptr : K.Get().Data(), + packed_qkv ? nullptr : V.Get().Data(), past_key, past_value, output, present_k, present_v, + seqlens_k, parameters, allocator, context); +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.h new file mode 100644 index 0000000000..cbe99e0378 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "gqa_attention_base.h" + +namespace onnxruntime { +namespace contrib { + +template +class GroupQueryAttention final : public OpKernel, public GQAAttentionBase { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h new file mode 100644 index 0000000000..e2a615eba4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace group_query_attention_helper { + +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, + void* parameters, + int num_heads, + int kv_num_heads, + const Tensor* seqlens_k, + const Tensor* total_seqlen, + float scale) { + // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache + // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr + // no packing for q/k/v: + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr + + AttentionQkvFormat qkv_format = Q_K_V_BSNH; + AttentionQkvFormat past_kv_format = Q_K_V_BNSH; + const bool is_packed_qkv = key == nullptr; + + const auto& query_dims = query->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + query_dims.size()); + } + + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int q_hidden_size = static_cast(query_dims[2]); + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } + + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + + // Check past-present KV + int32_t past_sequence_length = 0; + if (past_key != nullptr && past_value != nullptr) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); + + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + } else if (past_key != nullptr || past_value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall be both present or both absent."); + } + + // Check seqlens_k tensor (holding past seqlen for token gen) + const auto& seqlens_dim = seqlens_k->Shape().GetDims(); + if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k must be shape (batch_size)."); + } + + // Set present sequence length and kv_share_buffer from input total_seqlen tensor + if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "total_sequence_length tensor must be of one element."); + } + int total_sequence_length = *((*total_seqlen).template Data()); + int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + + int rotary_dim = 0; + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] < present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 should be of max_sequence_length."); + } + if (sin_dims[0] < present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 should be of max_sequence_length."); + } + if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (cos_dims[1] != sin_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache dimension 1 must be the same."); + } + rotary_dim = static_cast(cos_dims[1] * 2); + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + + bool is_prompt = sequence_length != 1; + + if (parameters != nullptr) { + GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; // sequence length of Q + output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors + output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->hidden_size = q_hidden_size; + output_parameters->num_heads = num_heads; + output_parameters->head_size = head_size; + output_parameters->kv_hidden_size = kv_hidden_size; + output_parameters->kv_num_heads = kv_num_heads; + output_parameters->rotary_dim = rotary_dim; + output_parameters->is_packed_qkv = is_packed_qkv; + output_parameters->is_unidirectional = true; + output_parameters->is_prompt = is_prompt; + output_parameters->scale = scale; + output_parameters->qkv_format = qkv_format; + output_parameters->past_kv_format = past_kv_format; + } + + return Status::OK(); +} + +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, + void* parameters, + int num_heads, + int kv_num_heads, + const Tensor* seqlens_k, + const Tensor* total_seqlen, + float scale, + int max_threads_per_block) { + if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); + } + + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale); +} + +template +Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, GroupQueryAttentionParameters parameters, const T* input, + T* output) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int n_heads = parameters.num_heads; + const int kv_n_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + int seq_stride = head_size; + int head_stride = sequence_length * seq_stride; + int batch_stride = (n_heads + 2 * kv_n_heads) * head_stride; + + const int loop_len = batch_size * sequence_length * kv_n_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / kv_n_heads) / sequence_length); + const int s = static_cast((ptr / kv_n_heads) % sequence_length); + const int n = static_cast(ptr % kv_n_heads); + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + for (int i = 0; i < head_size; i++) { + output_data[i] = input_data[i]; + } + } + }); + return Status::OK(); +} + +} // namespace group_query_attention_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index c4e4b4ec70..6e85be15d9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -4,15 +4,13 @@ #include "attention_cpu_base.h" #include "multihead_attention.h" #include "multihead_attention_helper.h" +#include "attention_utils.h" #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" -#include "core/framework/transpose_helper.h" #include "core/graph/onnx_protobuf.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/math/element_wise_ops.h" -#include "core/providers/cpu/tensor/reshape_helper.h" #include #include @@ -43,217 +41,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } -// Reshape Q/K/V from BxSxD to BxSxNxH -Status Reshape_BSD_to_BSNH(Tensor* qkv, - int batch_size, - int sequence_length, - int num_heads, - int head_size) { - std::vector reshape_dims({batch_size, sequence_length, num_heads, head_size}); - gsl::span reshape_dims_span{reshape_dims}; - TensorShape qkv_bsnh(reshape_dims_span); - qkv->Reshape(qkv_bsnh); - return Status::OK(); -} - -// Transpose Q/K/V from BxSxNxH to BxNxSxH -Status Transpose_BSNH_to_BNSH(const Tensor* qkv, - OrtValue& qkv_transposed, - concurrency::ThreadPool* tp = nullptr) { - std::vector permutations({0, 2, 1, 3}); - gsl::span permutations_span{permutations}; - size_t from = 2, to = 1; - SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp); - return Status::OK(); -} - -// Add bias + transpose for each of Q/K/V -template -Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v - const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v) - OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v - int bias_offset, // bias offset to enter qkv_bias - int batch_size, // batch size - int sequence_length, // sequence_length for Q, kv_sequence_length for K/V - int num_heads, // num heads - int head_size, // head_size for Q/K, v_head_size for V - int hidden_size, // hidden_size for Q/K, v_hidden_size for V - OpKernelContext* context) { - // Note: the comments below will refer to Q's dimensions for simplicity - auto element_type = DataTypeImpl::GetType(); - constexpr size_t element_size = sizeof(T); - ProcessBroadcastSpanFuncs add_funcs{ - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); - }, - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); - }, - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); - }}; // For element-wise add - - // Allocate space for output of Q(BS, D) + bias(D) - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - std::vector old_dims({batch_size, sequence_length, hidden_size}); - gsl::span old_dims_span{old_dims}; - TensorShape qkv_with_bias_shape(old_dims_span); - OrtValue qkv_with_bias; - Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias); - - // Get Q's bias from combined bias - std::vector bias_dims({hidden_size}); - gsl::span bias_dims_span{bias_dims}; - TensorShape bias_shape(bias_dims_span); - OrtValue bias; - Tensor::InitOrtValue(element_type, bias_shape, allocator, bias); - memcpy(bias.GetMutable()->MutableData(), qkv_bias + bias_offset, hidden_size * element_size); - - // Compute Q(BS, D) + bias(D) as broadcasted element-wise add - { - InputBroadcaster input_broadcaster(*bias.GetMutable(), *qkv); - const InputBroadcaster& const_input_broadcaster = input_broadcaster; - Tensor& output_tensor = *qkv_with_bias.GetMutable(); - - size_t span_size = input_broadcaster.GetSpanSize(); - size_t output_size = static_cast(output_tensor.Shape().Size()); - void* user_data = nullptr; - - const int loop_len = static_cast(output_size / span_size); - double unit_cost = 1.0f; - const auto cost = TensorOpCost{static_cast(input_broadcaster.Input0ElementSize()) * span_size, - static_cast(output_tensor.DataType()->Size()) * span_size, - unit_cost * span_size}; - auto tp = context->GetOperatorThreadPool(); - ThreadPool::TryParallelFor(tp, loop_len, cost, - [span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span, - std::ptrdiff_t last_span) { - InputBroadcaster segment_input_broadcaster(const_input_broadcaster); - segment_input_broadcaster.AdvanceBy(first_span * span_size); - - OutputBroadcaster segment_output_broadcaster(span_size, output_tensor, - first_span * span_size, last_span * span_size); - - BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data); - BroadcastLooper(segment_helper, add_funcs); - }); - } - - // Reshape Q from BxSxD to BxSxNxH - ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); - - // Transpose Q from BxSxNxH to BxNxSxH - auto tp = context->GetOperatorThreadPool(); - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp)); - - return Status::OK(); -} - -// Add bias + reshape for each of Q/K/V -// This is used in decoder_with_past when the sequence length is 1 -template -Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v - const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v) - OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v - int bias_offset, // bias offset to enter qkv_bias - int batch_size, // batch size - int sequence_length, // sequence_length for Q, kv_sequence_length for K/V - int num_heads, // num heads - int head_size, // head_size for Q/K, v_head_size for V - int hidden_size, // hidden_size for Q/K, v_hidden_size for V - OpKernelContext* context) { - // Note: the comments below will refer to Q's dimensions for simplicity - auto element_type = DataTypeImpl::GetType(); - constexpr size_t element_size = sizeof(T); - ProcessBroadcastSpanFuncs add_funcs{ - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); - }, - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); - }, - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); - }}; // For element-wise add - - // Get Q's bias from combined bias - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - std::vector bias_dims({hidden_size}); - gsl::span bias_dims_span{bias_dims}; - TensorShape bias_shape(bias_dims_span); - OrtValue bias; - Tensor::InitOrtValue(element_type, bias_shape, allocator, bias); - auto num_bias_elements = SafeInt(hidden_size) * element_size; - memcpy(bias.GetMutable()->MutableData(), qkv_bias + bias_offset, num_bias_elements); - - // Compute Q(BS, D) + bias(D) as broadcasted element-wise add - { - InputBroadcaster input_broadcaster(*bias.GetMutable(), *qkv); - const InputBroadcaster& const_input_broadcaster = input_broadcaster; - Tensor& output_tensor = *qkv_with_bias.GetMutable(); - - size_t span_size = input_broadcaster.GetSpanSize(); - size_t output_size = static_cast(output_tensor.Shape().Size()); - void* user_data = nullptr; - - const int loop_len = static_cast(output_size / span_size); - double unit_cost = 1.0f; - const auto cost = TensorOpCost{static_cast(input_broadcaster.Input0ElementSize()) * span_size, - static_cast(output_tensor.DataType()->Size()) * span_size, - unit_cost * span_size}; - auto tp = context->GetOperatorThreadPool(); - ThreadPool::TryParallelFor(tp, loop_len, cost, - [span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span, - std::ptrdiff_t last_span) { - InputBroadcaster segment_input_broadcaster(const_input_broadcaster); - segment_input_broadcaster.AdvanceBy(first_span * span_size); - - OutputBroadcaster segment_output_broadcaster(span_size, output_tensor, - first_span * span_size, last_span * span_size); - - BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data); - BroadcastLooper(segment_helper, add_funcs); - }); - } - - // Reshape Q from BxSxD to BxNxSxH - std::vector reshape_dims({batch_size, num_heads, sequence_length, head_size}); - gsl::span reshape_dims_span{reshape_dims}; - TensorShape qkv_final_dims(reshape_dims_span); - qkv_with_bias.GetMutable()->Reshape(qkv_final_dims); - - return Status::OK(); -} - -template -Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, - int batch_size, int num_heads, int sequence_length, int head_size, - const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) { - auto element_type = DataTypeImpl::GetType(); - std::vector new_dims({batch_size, num_heads, sequence_length, head_size}); - gsl::span new_dims_span{new_dims}; - TensorShape v_BNLH(new_dims_span); - Tensor::InitOrtValue(element_type, v_BNLH, allocator, out); - if (bias == nullptr) { - std::unique_ptr reshaped; - if (in->Shape().GetDims().size() == 3) { - reshaped = std::make_unique(in->DataType(), in->Shape(), const_cast(in->DataRaw()), in->Location()); - ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size)); - } - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out)); - } else { - const auto* qkv_bias = bias->Data(); - if (sequence_length == 1) { - ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context)); - } else { - ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context)); - } - } - return Status::OK(); -}; - template Status MultiHeadAttention::Compute(OpKernelContext* context) const { const Tensor* query = context->Input(0); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index aa8b5b5f60..195ebdf6a4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -36,6 +36,71 @@ RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { } } +// TODO: rotary embedding in place +template +Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input, + const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output, + bool interleaved) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int n_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int head_stride = parameters.head_stride; + const int seq_stride = parameters.seq_stride; + const int batch_stride = parameters.batch_stride; + const int position_ids_format = parameters.position_ids_format; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); + + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) + const int position_id = (position_ids_format == 0) + ? static_cast(position_ids[0]) + s + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_rotary_emb_dim; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + for (int i = 0; i < rotary_emb_dim; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_rotary_emb_dim; + sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } + for (int i = rotary_emb_dim; i < head_size; i++) { + output_data[i] = input_data[i]; + } + } + }); + + return Status::OK(); +} + +template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input, + const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output, + bool interleaved); + template Status RotaryEmbedding::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); @@ -65,72 +130,12 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const T* sin_cache_data = sin_cache->Data(); T* output_dest = output->MutableData(); - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int n_heads = parameters.num_heads; - const int head_size = parameters.head_size; - const int position_ids_format = parameters.position_ids_format; - const int rotary_emb_dim = parameters.rotary_embedding_dim; - const int half_rotary_emb_dim = rotary_emb_dim / 2; - - // Default input tensor shape is [batch, seq_len, hidden_size] - int head_stride = head_size; - int seq_stride = n_heads * head_stride; - int batch_stride = sequence_length * seq_stride; - if (parameters.transposed) { - // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] - seq_stride = head_size; - head_stride = sequence_length * seq_stride; - batch_stride = n_heads * head_stride; - } - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * n_heads; - const double cost = static_cast(rotary_emb_dim); - ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / n_heads) / sequence_length); - const int s = static_cast((ptr / n_heads) % sequence_length); - const int n = static_cast(ptr % n_heads); - - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - - const T* input_data = input_src + block_offset; - T* output_data = output_dest + block_offset; - - // Cache is (M, H/2) or (M, rotary_embedding_dim/2) - const int position_id = (position_ids_format == 0) - ? static_cast(pos_ids_data[0]) + s - : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_rotary_emb_dim; - const T* cos_data = cos_cache_data + cache_offset; - const T* sin_data = sin_cache_data + cache_offset; - - int cache_idx = 0; - T sign = 0; - int j = 0; - for (int i = 0; i < rotary_emb_dim; i++) { - if (interleaved) { - cache_idx = (i / 2) % half_rotary_emb_dim; - sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); - j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign - } else { - cache_idx = i % half_rotary_emb_dim; - sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); - j = (i + half_rotary_emb_dim) % rotary_emb_dim; - } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; - } - for (int i = rotary_emb_dim; i < head_size; i++) { - output_data[i] = input_data[i]; - } - } - }); - - return Status::OK(); + return RunRotaryEmbedding(tp, parameters, input_src, pos_ids_data, cos_cache_data, sin_cache_data, output_dest, + interleaved); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index 4e32424a22..b291db538d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -4,10 +4,16 @@ #pragma once #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "rotary_embedding_helper.h" namespace onnxruntime { namespace contrib { +template +Status RunRotaryEmbedding(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_helper::RotaryParameters parameters, const T* input, + const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output, + bool interleaved); + template class RotaryEmbedding final : public OpKernel { public: diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index dcbb36d1c4..d6968484a1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -18,6 +18,9 @@ struct RotaryParameters { int rotary_embedding_dim; // Rotary embedding dimension. int num_heads; // num_heads = hidden_size / head_size int max_sequence_length; // Sequence length used by cos/sin cache + int head_stride; // Head stride + int seq_stride; // Sequence stride + int batch_stride; // Batch stride int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; @@ -116,6 +119,23 @@ Status CheckInputs(const T* input, "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } + num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); + // Calculate stride values + int head_stride; + int seq_stride; + int batch_stride; + if (transposed) { + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } else { + // Default input tensor shape is [batch, seq_len, hidden_size] + head_stride = head_size; + seq_stride = num_heads * head_stride; + batch_stride = sequence_length * seq_stride; + } + // Set rotary parameters if (parameters != nullptr) { RotaryParameters* output_parameters = reinterpret_cast(parameters); @@ -123,8 +143,11 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); + output_parameters->num_heads = num_heads; output_parameters->max_sequence_length = max_sequence_length; + output_parameters->head_stride = head_stride; + output_parameters->seq_stride = seq_stride; + output_parameters->batch_stride = batch_stride; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f9d9b13f0f..b37ce2f72e 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); @@ -257,6 +258,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index adfa1b61e1..2796d532cc 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1009,6 +1009,9 @@ constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( Group Query Self/Cross Attention. Supports different number of heads for q and kv. Only supports causal or local attention. +Supports rotary position embedding. +Supports k-v cache. +CPU EP supports fp32... CUDA EP supports fp16. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1094,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") - .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { GroupQueryAttentionTypeAndShapeInference(ctx, 3); diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py similarity index 100% rename from onnxruntime/test/python/transformers/test_flash_attn.py rename to onnxruntime/test/python/transformers/test_flash_attn_cuda.py diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py new file mode 100644 index 0000000000..d2f3a52d3b --- /dev/null +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -0,0 +1,1884 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import math +import random +import unittest + +import numpy +import torch +from bert_padding import pad_input, unpad_input + +try: + from colorama import Fore, init + + init(autoreset=True) +except ImportError: + print("colorama is not installed, please install it to get prettier output") + Fore = None +from einops import rearrange, repeat +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, OrtValue, SessionOptions + +torch.manual_seed(0) + +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + + +class Formats: + BSNH = 0 + BNSH = 1 + + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 + past_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, s, s2, sp, n, n2, h): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.past_sequence_length = sp + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + + +class PromptConfig: + batch_size = 0 + q_sequence_length = 0 + kv_sequence_length = 0 + buffer_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, sq, skv, sb, n, n2, h): + self.batch_size = b + self.q_sequence_length = sq + self.kv_sequence_length = skv + self.buffer_sequence_length = sb + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + + +# LLaMA Microsoft model +class LlamaMSRotaryEmbedding(torch.nn.Module): + def __init__(self): + super().__init__() + + def rotate_tensor( + self, + x: torch.Tensor, # BxSxNxH + cos: torch.Tensor, # 1xSx1x(H/2) + sin: torch.Tensor, # 1xSx1x(H/2) + pos: torch.Tensor, + interleaved: bool, + ): + # Dimension of x is [batch_size, seq_len, n_heads, head_dim] + rot_dim = 2 * cos.shape[3] + + # Dolly requires partial rotation + x_rot = x[:, :, :, :rot_dim] + + if interleaved: + x1 = x_rot[:, :, :, 0::2] + x2 = x_rot[:, :, :, 1::2] + else: + half = x_rot.shape[-1] // 2 + x1 = x[:, :, :, 0:half] + x2 = x[:, :, :, half : 2 * half] + + seq_len = x.shape[1] + + # cos_x: (1, S, 1, H/2) + # sin_x: (1, S, 1, H/2) + # x1: (B, S, N, H/2) + # x2: (B, S, N, H/2) + if seq_len == 1: + batch_size = x.shape[0] + pos_i = pos.unsqueeze(1).unsqueeze(2).unsqueeze(3).long() + cos_x = cos.expand(batch_size, -1, -1, -1) + sin_x = sin.expand(batch_size, -1, -1, -1) + cos_x = cos_x.gather(1, pos_i.expand(-1, -1, cos.shape[2], cos.shape[3])) + sin_x = sin_x.gather(1, pos_i.expand(-1, -1, sin.shape[2], sin.shape[3])) + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + else: + cos_x = cos[:, 0:seq_len, :, :] + sin_x = sin[:, 0:seq_len, :, :] + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + + return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1) + + def forward(self, x, cos, sin, pos, interleaved): + return self.rotate_tensor(x, cos, sin, pos, interleaved) + + +def create_group_query_attention_graph_prompt( + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, +): + past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 + present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key" if not packed else "", + "value" if not packed else "", + "past_key" if share_buffer else "", + "past_value" if share_buffer else "", + "seqlens_k", + "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT, + [ + config.batch_size, + config.q_sequence_length, + ( + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) + ), + ], + ), + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + if share_buffer: + graph_input += [ + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_group_query_attention_graph_past( + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, +): + past_kv_seqlen = config.kv_sequence_length + present_kv_seqlen = ( + config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length + ) + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key" if not packed else "", + "value" if not packed else "", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT, + [ + config.batch_size, + config.sequence_length, + ( + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) + ), + ], + ), + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) + else: + lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + + def output_pad_fn(output_unpad): + return pad_input(output_unpad, indices_q, batch_size, seqlen_q) + + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + + def output_pad_fn(output_unpad): + return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + + def dqkv_pad_fn(dqkv_unpad): + return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + + else: + + def dqkv_pad_fn(dqkv_unpad): + return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dkv_pad_fn(dkv_unpad): + return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dkv_pad_fn(dkv_unpad): + return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dk_pad_fn(dk_unpad): + return pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dk_pad_fn(dk_unpad): + return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def create_inputs(config: Config, kv_packed=False, qkv_packed=True): + qkv = torch.randn( + config.batch_size, + config.sequence_length, + 3, + config.num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + key_padding_mask = generate_random_padding_mask( + config.sequence_length, config.batch_size, device="cpu", mode="random" + ) + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, kv_packed, qkv_packed + ) + return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask + + +def generate_token_offset(cu_seqlens, max_seqlen): + token_offset = [] + token_padset = [] # These are the indices that contain padding tokens + for i in range(1, len(cu_seqlens)): + start = i - 1 + pre_seqlen = cu_seqlens[i - 1] + seqlen = cu_seqlens[i] + token_offset += range(start * max_seqlen, (start * max_seqlen) + (seqlen - pre_seqlen)) + token_padset += range((start * max_seqlen) + (seqlen - pre_seqlen), i * max_seqlen) + return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) + + +def gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + window_size=-1, + past_kv_format=Formats.BSNH, + share_buffer=True, + rotary_interleaved=False, +): + onnx_model_str = create_group_query_attention_graph_prompt( + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, + ) + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + past_k = k.clone() if share_buffer else None + past_v = v.clone() if share_buffer else None + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cpu", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cpu", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + # TODO: do we need io binding for cpu input? + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_input( + "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cpu", + 0, + numpy.float32, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + + +def gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, +): + onnx_model_str = create_group_query_attention_graph_past( + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, + ) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + past_k = k.clone() + past_v = v.clone() + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cpu", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cpu", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_input( + "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cpu", + 0, + numpy.float32, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": past_k.detach().cpu().numpy(), + "past_value": past_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor( + [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 + ) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + + +def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + return col_idx > row_idx + sk - sq + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def attention_qkvpacked_ref( + qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + reorder_ops=reorder_ops, + ) + + +def parity_check_gqa_prompt( + config, + causal=True, + local=False, + past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(1, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) + rotary_seqlens = torch.tensor([0], device="cpu").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float32) + sin = torch.sin(angle).to(dtype=torch.float32) + rot = LlamaMSRotaryEmbedding() + q_ro = rot( + q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved + ) + k_ro = rot( + new_k.clone(), + cos.unsqueeze(0).unsqueeze(2), + sin.unsqueeze(0).unsqueeze(2), + rotary_seqlens, + rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + + rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") + arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) + kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") + update_mask = arange < kv_seqlens_expanded + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + out_ref, _ = attention_ref( + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + if Fore is not None: + correct = Fore.GREEN + "True" if all_close else Fore.RED + "False" + else: + correct = "True" if all_close else "False" + print( + "KV-buffer", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.q_sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + correct, + ) + + +def parity_check_gqa_prompt_no_buff( + config, + causal=True, + local=False, + past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(1, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + + # Pytorch to compare + k_cache_ref = new_k.clone() + v_cache_ref = new_v.clone() + cache_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) + rotary_seqlens = torch.tensor([0], device="cpu").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float32) + sin = torch.sin(angle).to(dtype=torch.float32) + rot = LlamaMSRotaryEmbedding() + q_ro = rot( + q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved + ) + k_ro = rot( + k_cache_ref.clone(), + cos.unsqueeze(0).unsqueeze(2), + sin.unsqueeze(0).unsqueeze(2), + rotary_seqlens, + rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k_cache_ref + k_cache_ref = k_ro + + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + new_mask = brange < cache_seqlens_expanded + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + out_ref, _ = attention_ref( + q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + ) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + None, + None, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + None, + None, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + if Fore is not None: + correct = Fore.GREEN + "True" if all_close else Fore.RED + "False" + else: + correct = "True" if all_close else "False" + print( + "No buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.q_sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + correct, + ) + + +def parity_check_gqa_past( + config, + causal=True, + local=False, + past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(1, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length - config.sequence_length + 1, + (config.batch_size,), + dtype=torch.int32, + device="cpu", + ) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float32) + sin = torch.sin(angle).to(dtype=torch.float32) + rot = LlamaMSRotaryEmbedding() + q_ro = rot( + q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved + ) + k_ro = rot( + new_k.clone(), + cos.unsqueeze(0).unsqueeze(2), + sin.unsqueeze(0).unsqueeze(2), + cache_seqlens, + rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + + arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref( + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # ORT function + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + if Fore is not None: + correct = Fore.GREEN + "True" if all_close else Fore.RED + "False" + else: + correct = "True" if all_close else "False" + print( + "KV-buffer", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + correct, + ) + + +def parity_check_gqa_past_no_buff( + config, + causal=True, + local=False, + past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, + rtol=1e-3, + atol=1e-3, +): + torch.manual_seed(69) + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(1, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length, + (config.batch_size,), + dtype=torch.int32, + device="cpu", + ) + cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = ( + torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + ) + cos = torch.cos(angle).to(dtype=torch.float32) + sin = torch.sin(angle).to(dtype=torch.float32) + rot = LlamaMSRotaryEmbedding() + q_ro = rot( + q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved + ) + k_ro = rot( + new_k.clone(), + cos.unsqueeze(0).unsqueeze(2), + sin.unsqueeze(0).unsqueeze(2), + cache_seqlens, + rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref( + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Compare results + all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + # if not all_close: + # print("seqlens", cache_seqlens) + # print("out", out) + # print("out_ref", out_ref) + # print(out - out_ref) + if Fore is not None: + correct = Fore.GREEN + "True" if all_close else Fore.RED + "False" + else: + correct = "True" if all_close else "False" + print( + "NO buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + correct, + ) + + +class TestGQA(unittest.TestCase): + def test_gqa_no_past(self): + torch.manual_seed(69) + print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") + batches = [1, 3] if pipeline_mode else [1, 3, 5] + seqs = ( + [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + if pipeline_mode + else [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + ) + num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: + for packed in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + past_kv_format = Formats.BNSH + parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + + def test_gqa_past(self): + print("-------- TEST GQA PAST (TOKEN GEN) ---------") + batches = [1, 3] if pipeline_mode else [1, 3, 5] + seqs = ( + [(1, 128), (1, 1024), (1, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(16, 16), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: + for packed in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + past_kv_format = Formats.BNSH + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + + +if __name__ == "__main__": + unittest.main()