diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9dc183e0e7..fc559411df 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5553,11 +5553,29 @@ This version of the operator has been available since version 1 of the 'com.micr
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
- Padding shall be on the right side.
+ The block_row_indices and block_col_indices are the CSR representation of the block mask. The block_col_indices might contain
+ padding on the right side when the layouts have different numbers of non-zero blocks in the block mask.
- When do_rotary is True, cos_cache and sin_cache are required.
+ An example of block mask with 2 layouts where each layout is 4 x 4 blocks:
+ [[[1, 0, 0, 0],
+ [1, 1, 0, 0],
+ [0, 1, 1, 0],
+ [0, 1, 1, 1]],
+
+ [[1, 0, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 0],
+ [1, 0, 1, 1]]]
+
+ The corresponding CSR format:
+ block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
+ block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]
+
+ When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
+ or sin cache can be different from the maximum sequence length used by kv cache.
Only supports unidirectional attention with cache of past key and value in linear buffers.
+
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
#### Version
@@ -5581,7 +5599,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of tokens per sparse block. Choices: 16, 32, 64, 128
-#### Inputs (8 - 10)
+#### Inputs (9 - 11)
- query : T
@@ -5590,20 +5608,22 @@ This version of the operator has been available since version 1 of the 'com.micr
- Key with shape (batch_size, sequence_length, kv_num_heads * head_size)
- value (optional) : T
- Value with shape (batch_size, sequence_length, kv_num_heads * head_size)
-- past_key (optional) : T
-- Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)
-- past_value (optional) : T
-- Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)
-- block_mask : M
-- block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.
+- past_key : T
+- Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
+- past_value : T
+- Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
+- block_row_indices : M
+- The row indices of the CSR format of the block mask, with shape (num_layout, max_blocks + 1). The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.
+- block_col_indices : M
+- The col indices of the CSR format of the block mask, with shape (num_layout, max_nnz_blocks). The max_nnz_blocks is the maximum number of non-zeros per layout in the block mask.
- total_sequence_length : M
- Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.
- key_total_sequence_lengths : M
- 1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.
- cos_cache (optional) : T
-- Cos cache of rotary with shape (max_sequence_length, head_size / 2).
+- Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).
- sin_cache (optional) : T
-- Sin cache of rotary with shape (max_sequence_length, head_size / 2).
+- Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).
#### Outputs
@@ -5612,9 +5632,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)
present_key : T
-Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).
+Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).
present_value : T
-Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).
+Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).
#### Type Constraints
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 414dfaab8d..3c33143a97 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -906,7 +906,7 @@ Do not modify directly.*
|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
-|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_mask:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)|
+|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)|
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 7d07453337..a5b9c84c63 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -126,11 +126,15 @@ struct SparseAttentionParameters {
bool rotary_interleaved; // whether to use interleaved rotary embedding
int rotary_dim; // rotary embedding dimension
int sparse_block_size; // block size for sparse attention
- int num_sparse_layout; // number of sparse layout, or the first dimension of block_mask
+ int num_sparse_layout; // number of sparse layout
+ int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
+ int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
float scale; // scaling factor applied prior to softmax
bool is_packed_qkv; // whether qkv is packed
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
- int max_sequence_length; // max sequence length allowed
+ int max_sequence_length; // max sequence length for sparse layout
+ int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache
+ int max_cache_sequence_length; // max sequence length for kv cache buffer
bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value
};
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
index d5673b29cf..506a6683de 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
@@ -4,7 +4,6 @@
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
#include "contrib_ops/cuda/sparse/sparse_attention.h"
#include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
-#include "contrib_ops/cuda/sparse/block_mask.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
#include "core/platform/env_var_utils.h"
@@ -26,7 +25,7 @@ namespace cuda {
.TypeConstraint("M", DataTypeImpl::GetTensorType()) \
.MayInplace(3, 1) \
.MayInplace(4, 2) \
- .InputMemoryType(OrtMemTypeCPUInput, 6), \
+ .InputMemoryType(OrtMemTypeCPUInput, 7), \
SparseAttention);
REGISTER_KERNEL_TYPED(MLFloat16)
@@ -77,15 +76,16 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
const Tensor* value = context->Input(2);
const Tensor* past_key = context->Input(3);
const Tensor* past_value = context->Input(4);
- const Tensor* block_mask = context->Input(5);
- const Tensor* total_seq_len = context->Input(6);
- const Tensor* seqlens_k_total = context->Input(7);
- const Tensor* cos_cache = context->Input(8);
- const Tensor* sin_cache = context->Input(9);
+ const Tensor* block_row_indices = context->Input(5);
+ const Tensor* block_col_indices = context->Input(6);
+ const Tensor* total_seq_len = context->Input(7);
+ const Tensor* seqlens_k_total = context->Input(8);
+ const Tensor* cos_cache = context->Input(9);
+ const Tensor* sin_cache = context->Input(10);
SparseAttentionParameters parameters;
- // Parameters from node attribute
+ // Parameters from node attribute shall be set before calling CheckInputs
parameters.sparse_block_size = sparse_block_size_;
parameters.num_heads = num_heads_;
parameters.kv_num_heads = kv_num_heads_;
@@ -101,7 +101,8 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
past_value,
cos_cache,
sin_cache,
- block_mask,
+ block_row_indices,
+ block_col_indices,
seqlens_k_total,
total_seq_len));
// Some limitations of CUDA kernels
@@ -177,7 +178,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, output_shape);
std::vector present_dims = {
- parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
+ parameters.batch_size, parameters.kv_num_heads, parameters.max_cache_sequence_length, parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
@@ -188,13 +189,12 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
data.query = reinterpret_cast(query->Data());
data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data());
data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data());
- data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data());
- data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data());
- data.block_mask = block_mask->Data();
+ data.past_key = reinterpret_cast(past_key->Data());
+ data.past_value = reinterpret_cast(past_value->Data());
data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data();
data.output = reinterpret_cast(output->MutableData());
- data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData());
- data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData());
+ data.present_key = reinterpret_cast(present_key->MutableData());
+ data.present_value = reinterpret_cast(present_value->MutableData());
// Check past and present share buffer.
parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
@@ -214,29 +214,9 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
// Currently, we use same block size in kernel.
// TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask).
data.kernel_layout.block_size = parameters.sparse_block_size;
- data.kernel_layout.mask = data.block_mask;
data.kernel_layout.num_layout = parameters.num_sparse_layout;
- data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size;
- data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size;
-
- // Allocate buffer for CSR col and row indices.
- onnxruntime::Stream* stream = context->GetComputeStream();
- int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows;
- auto csr_col_indices_buffer = GetScratchBuffer(static_cast(dense_blocks), stream);
- auto csr_row_indices_buffer = GetScratchBuffer(
- static_cast(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream);
-
- data.kernel_layout.csr_col_indices = reinterpret_cast(csr_col_indices_buffer.get());
- data.kernel_layout.csr_row_indices = reinterpret_cast(csr_row_indices_buffer.get());
-
- ConvertMaskToCSR(cuda_stream,
- data.kernel_layout.mask,
- data.kernel_layout.num_layout,
- data.kernel_layout.num_rows,
- data.kernel_layout.num_cols,
- csr_row_indices_buffer.get(),
- csr_col_indices_buffer.get(),
- device_prop.maxThreadsPerBlock);
+ data.kernel_layout.csr_col_indices = block_col_indices->Data();
+ data.kernel_layout.csr_row_indices = block_row_indices->Data();
size_t rotary_buffer_bytes = 0;
if (do_rotary_) {
@@ -244,7 +224,8 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length * parameters.head_size;
rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
}
- auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream());
+ onnxruntime::Stream* stream = context->GetComputeStream();
+ auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, stream);
data.rotary_buffer = reinterpret_cast(rotary_buffer.get());
size_t transposed_q_bytes = 0;
@@ -252,7 +233,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
transposed_q_bytes = parameters.batch_size * parameters.sequence_length *
parameters.num_heads * parameters.head_size * sizeof(T);
}
- auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, context->GetComputeStream());
+ auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, stream);
if (transposed_q_buffer) {
data.transposed_q_buffer = reinterpret_cast(transposed_q_buffer.get());
}
@@ -263,7 +244,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
(parameters.num_heads + 2 * parameters.kv_num_heads) *
parameters.head_size * sizeof(T));
}
- auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream());
+ auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, stream);
if (unpacked_qkv_buffer) {
data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get());
}
@@ -327,7 +308,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const {
}
}
- v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, context->GetComputeStream());
+ v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, stream);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned,
sizeof(int32_t) * v2_kernel_buffer_size,
cudaMemcpyHostToDevice, cuda_stream));
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
index 7e98b374c4..a5f1d50e61 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
@@ -19,7 +19,8 @@ Status CheckInputs(void* params,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
- const Tensor* block_mask,
+ const Tensor* block_row_indices,
+ const Tensor* block_col_indices,
const Tensor* seqlens_k_total,
const Tensor* total_seq_len) {
// No packing for q/k/v:
@@ -31,14 +32,14 @@ Status CheckInputs(void* params,
// key nullptr
// value nullptr
// Shape for other inputs:
- // past_key (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
- // past_value (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
- // block_mask (num_heads, max_blocks, max_blocks) or (1, max_blocks, max_blocks)
- // where max_blocks = max_sequence_length / sparse_block_size
+ // past_key (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
+ // past_value (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
+ // block_row_indices (num_layout, max_blocks + 1), where max_blocks = max_sequence_length / sparse_block_size
+ // block_col_indices (num_layout, max_nnz)
// seqlens_k_total (batch_size) when do_rotary is True, optional otherwise
// total_seq_len (1)
- // cos_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
- // sin_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
+ // cos_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true.
+ // sin_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true.
assert(params != nullptr);
SparseAttentionParameters* parameters = reinterpret_cast(params);
@@ -121,57 +122,78 @@ Status CheckInputs(void* params,
kv_hidden_size = head_size * kv_num_heads;
}
- const auto& block_mask_dim = block_mask->Shape().GetDims();
- if (!(block_mask_dim.size() == 3 && block_mask_dim[1] == block_mask_dim[2] &&
- (static_cast(num_heads) % block_mask_dim[0] == 0L))) {
+ if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "total_sequence_length tensor must be of one element.");
+ }
+ int total_sequence_length = *((*total_seq_len).template Data());
+
+ // Check block_row_indices
+ const auto& block_row_indices_dim = block_row_indices->Shape().GetDims();
+ if (!(block_row_indices_dim.size() == 2 &&
+ block_row_indices_dim[1] > 1 &&
+ (static_cast(num_heads) % block_row_indices_dim[0] == 0L))) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
- "block_mask must have shape (num_layout, max_blocks, max_blocks) where num_heads is divisible by num_layout.");
+ "block_row_indices must have shape (num_layout, max_blocks + 1) where num_heads is divisible by num_layout.");
+ }
+ int max_blocks = static_cast(block_row_indices_dim[1]) - 1;
+
+ // Check block_col_indices
+ const auto& block_col_indices_dim = block_col_indices->Shape().GetDims();
+ if (!(block_col_indices_dim.size() == 2 &&
+ block_col_indices_dim[0] == block_row_indices_dim[0] &&
+ block_col_indices_dim[1] <= max_blocks * max_blocks)) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "block_col_indices must have shape (num_layout, max_nnz), "
+ "where max_nnz <= max_blocks * max_blocks.");
}
- int max_blocks = static_cast(block_mask_dim[1]);
int max_sequence_length = max_blocks * parameters->sparse_block_size;
-
- // Check past-present KV
- if (past_key != nullptr && past_value != nullptr) {
- if (past_key->Shape() != past_value->Shape()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' and 'past_value' shall have same shape");
- }
-
- const auto& past_key_dims = past_key->Shape().GetDims();
- if (past_key_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' is expected to have 4 dimensions, got ",
- past_key_dims.size());
- }
-
- if (past_key_dims[0] != batch_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ",
- past_key_dims[0]);
- }
-
- if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads");
- }
-
- int max_cache_sequence_length = static_cast(past_key_dims[is_past_bsnh ? 1 : 2]);
- if (max_cache_sequence_length != max_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' and 'block_mask' should have the same sequence length:",
- "max_sequence_length deduced from past_key is ", max_cache_sequence_length,
- "; max_sequence_length deduced from block_mask is ", max_sequence_length);
- }
-
- if (past_key_dims[3] != head_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 3 should be same as head_size, got ",
- past_key_dims[3]);
- }
- } else if (past_key != nullptr || past_value != nullptr) {
+ if (max_sequence_length < total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' and 'past_value' shall be both present or both absent.");
+ "max_sequence_length should be no less than total_sequence_length:",
+ total_sequence_length,
+ ", max_sequence_length deduced from block_row_indices:", max_sequence_length);
+ }
+
+ // Check kv cache
+ ORT_ENFORCE(past_key != nullptr && past_value != nullptr);
+ if (past_key->Shape() != past_value->Shape()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' and 'past_value' shall have same shape");
+ }
+
+ const auto& past_key_dims = past_key->Shape().GetDims();
+ if (past_key_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' is expected to have 4 dimensions, got ",
+ past_key_dims.size());
+ }
+
+ if (past_key_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ",
+ past_key_dims[0]);
+ }
+
+ if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads");
+ }
+
+ int max_cache_sequence_length = static_cast(past_key_dims[is_past_bsnh ? 1 : 2]);
+ if (max_cache_sequence_length < total_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "max_cache_sequence_length should be no less than total_sequence_length:",
+ total_sequence_length,
+ ", max_cache_sequence_length:", max_cache_sequence_length);
+ }
+
+ if (past_key_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 3 should be same as head_size, got ",
+ past_key_dims[3]);
}
// Check the shape of total_key_sequence_lengths. We do not check the values here.
@@ -181,13 +203,8 @@ Status CheckInputs(void* params,
"key_total_sequence_lengths must have shape (batch_size).");
}
- if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "total_sequence_length tensor must be of one element.");
- }
- int total_sequence_length = *((*total_seq_len).template Data());
-
int rotary_dim = 0;
+ int max_rotary_sequence_length = 0;
if (do_rotary) {
if (cos_cache == nullptr || sin_cache == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -202,14 +219,19 @@ Status CheckInputs(void* params,
"head_size shall be a multiple of 16. Got head_size = ",
head_size);
}
- if (cos_dims[0] < total_sequence_length) {
+ if (cos_dims[0] != sin_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "cos_cache dimension 0 should be not be less than total_sequence_length.");
+ "cos_cache and sin_cache dimension 0 should be same size.");
}
- if (sin_dims[0] < total_sequence_length) {
+
+ max_rotary_sequence_length = static_cast(cos_dims[0]);
+ if (max_rotary_sequence_length < total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "sin_cache dimension 0 should be not be less than total_sequence_length.");
+ "max_rotary_sequence_length should be no less than total_sequence_length:",
+ total_sequence_length,
+ ", max_rotary_sequence_length:", max_rotary_sequence_length);
}
+
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
@@ -229,12 +251,16 @@ Status CheckInputs(void* params,
parameters->sequence_length = sequence_length;
parameters->total_sequence_length = total_sequence_length;
parameters->max_sequence_length = max_sequence_length;
+ parameters->max_cache_sequence_length = max_cache_sequence_length;
+ parameters->max_rotary_sequence_length = max_rotary_sequence_length;
parameters->hidden_size = q_hidden_size;
parameters->head_size = head_size;
parameters->kv_hidden_size = kv_hidden_size;
parameters->rotary_dim = rotary_dim;
parameters->is_packed_qkv = is_packed_qkv;
- parameters->num_sparse_layout = static_cast(block_mask_dim[0]);
+ parameters->num_sparse_layout = static_cast(block_row_indices_dim[0]);
+ parameters->stride_row_indices = static_cast(block_row_indices_dim[1]);
+ parameters->stride_col_indices = static_cast(block_col_indices_dim[1]);
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu
index 5d6182b613..d833a7cf02 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
-#include "contrib_ops/cuda/sparse/block_mask.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
@@ -88,7 +87,7 @@ Status LaunchConcatKVInPlace(contrib::SparseAttentionParameters& parameters,
return LaunchConcatKVInPlace(parameters.batch_size,
parameters.kv_num_heads,
parameters.head_size,
- parameters.max_sequence_length,
+ parameters.max_cache_sequence_length,
nullptr,
data.seqlens_k_total,
parameters.sequence_length,
@@ -112,7 +111,6 @@ Status QkvToContext(
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- // const int present_sequence_length = parameters.max_sequence_length;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
@@ -182,14 +180,14 @@ Status QkvToContext(
position_ids_buff, data.cos_cache, data.sin_cache,
parameters.batch_size, parameters.sequence_length,
parameters.num_heads, parameters.head_size,
- parameters.rotary_dim, parameters.max_sequence_length,
+ parameters.rotary_dim, parameters.max_rotary_sequence_length,
/*position_ids_format*/ 1, parameters.rotary_interleaved,
max_threads_per_block, q_layout));
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key),
position_ids_buff, data.cos_cache, data.sin_cache,
parameters.batch_size, parameters.sequence_length,
parameters.kv_num_heads, parameters.head_size,
- parameters.rotary_dim, parameters.max_sequence_length,
+ parameters.rotary_dim, parameters.max_rotary_sequence_length,
/*position_ids_format*/ 1, parameters.rotary_interleaved,
max_threads_per_block, kv_layout));
query = reinterpret_cast(q_buffer);
@@ -215,29 +213,24 @@ Status QkvToContext(
// TODO: only dump to total sequence length instead of max sequence length.
#if DUMP_TENSOR_LEVEL > 0
- DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
- DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
-
- DUMP_TENSOR("block_mask",
- data.kernel_layout.mask,
- data.kernel_layout.num_layout,
- data.kernel_layout.num_rows,
- data.kernel_layout.num_cols);
+ DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads,
+ parameters.max_cache_sequence_length, head_size);
+ DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads,
+ parameters.max_cache_sequence_length, head_size);
DUMP_TENSOR("csr_col_indices",
data.kernel_layout.csr_col_indices,
data.kernel_layout.num_layout,
- data.kernel_layout.num_rows,
- data.kernel_layout.num_cols);
+ parameters.stride_col_indices);
DUMP_TENSOR("csr_row_indices",
data.kernel_layout.csr_row_indices,
data.kernel_layout.num_layout,
- data.kernel_layout.num_rows + 1);
+ parameters.stride_row_indices);
printf(
"batch_size=%d, sequence_length=%d, num_heads=%d, kv_num_heads=%d head_size=%d, "
- "total_sequence_length=%d, max_sequence_length=%d scale=%f block_size=%d "
+ "total_sequence_length=%d, max_sequence_length=%d max_cache_sequence_length=%d scale=%f block_size=%d "
"row_stride=%d col_stride=%d num_layout=%d\n",
parameters.batch_size,
parameters.sequence_length,
@@ -246,10 +239,11 @@ Status QkvToContext(
parameters.head_size,
parameters.total_sequence_length,
parameters.max_sequence_length,
+ parameters.max_cache_sequence_length,
parameters.scale,
data.kernel_layout.block_size,
- data.kernel_layout.num_rows + 1,
- data.kernel_layout.num_rows * data.kernel_layout.num_cols,
+ parameters.stride_row_indices,
+ parameters.stride_col_indices,
data.kernel_layout.num_layout);
#endif
@@ -262,19 +256,20 @@ Status QkvToContext(
reinterpret_cast(query),
reinterpret_cast(data.present_key),
reinterpret_cast(data.present_value),
+ q_layout == LAYOUT_BNSH,
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
parameters.total_sequence_length,
- parameters.max_sequence_length,
+ parameters.max_cache_sequence_length,
parameters.scale,
- data.kernel_layout.block_size, // kernel_block_size
- data.kernel_layout.csr_row_indices, // skip past_seq_len in row indices
- data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
- data.kernel_layout.num_rows + 1, // stride per head in row indices
- data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
+ data.kernel_layout.block_size, // kernel_block_size
+ data.kernel_layout.csr_row_indices, // shape (num_layout, stride_row_indices)
+ data.kernel_layout.csr_col_indices, // shape (num_layout, stride_col_indices)
+ parameters.stride_row_indices,
+ parameters.stride_col_indices,
data.kernel_layout.num_layout,
data.active_q_blocks,
data.q_batch_starts,
@@ -297,19 +292,20 @@ Status QkvToContext(
reinterpret_cast(query),
reinterpret_cast(data.present_key),
reinterpret_cast(data.present_value),
+ q_layout == LAYOUT_BNSH,
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
parameters.total_sequence_length,
- parameters.max_sequence_length,
+ parameters.max_cache_sequence_length,
parameters.scale,
- data.kernel_layout.block_size, // kernel_block_size
- data.kernel_layout.csr_row_indices, // (num_layout, num_rows + 1)
- data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
- data.kernel_layout.num_rows + 1, // stride per head in row indices
- data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
+ data.kernel_layout.block_size, // kernel_block_size
+ data.kernel_layout.csr_row_indices, // (num_layout, stride_row_indices)
+ data.kernel_layout.csr_col_indices, // (num_layout, stride_col_indices)
+ parameters.stride_row_indices,
+ parameters.stride_col_indices,
data.kernel_layout.num_layout);
if constexpr (std::is_same::value) {
@@ -319,7 +315,7 @@ Status QkvToContext(
}
}
- DUMP_TENSOR("output", reinterpret_cast(data.output), batch_size, num_heads, sequence_length, head_size);
+ DUMP_TENSOR("output", reinterpret_cast(data.output), batch_size, sequence_length, num_heads, head_size);
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h
index 03e2a3dd08..0b07b234b7 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h
@@ -16,14 +16,10 @@ namespace contrib {
namespace cuda {
struct BlockLayout {
- const int32_t* mask; // shape (num_layout, num_rows, num_cols), where num_rows = num_cols = max_seq_len / block_size.
int num_layout;
- int block_size; // kernel block size, which is <= sparse_block_size
-
- const int* csr_col_indices;
- const int* csr_row_indices;
- int num_rows;
- int num_cols;
+ int block_size; // kernel block size, which is <= sparse_block_size
+ const int* csr_row_indices; // shape [num_layout, stride_row_indices]
+ const int* csr_col_indices; // shape [num_layout, stride_col_indices]
};
template
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h
index d69d3621d0..a90c603d7d 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h
@@ -22,6 +22,8 @@ struct SparseAttentionParams {
const void* k;
const void* v;
+ bool is_q_bnsh;
+
int batch_size;
int num_heads;
int kv_num_heads;
@@ -30,7 +32,7 @@ struct SparseAttentionParams {
int sequence_length;
int past_sequence_length;
int total_sequence_length;
- int max_sequence_length;
+ int max_cache_sequence_length;
float scale;
@@ -64,13 +66,14 @@ struct SparseAttentionParams {
const void* q,
const void* k,
const void* v,
+ bool is_q_bnsh,
int batch_size,
int sequence_length,
int num_heads,
int kv_num_heads,
int head_size,
int total_sequence_length,
- int max_sequence_length,
+ int max_cache_sequence_length,
float scale,
int kernel_block_size,
const int* layout_csr_row_indices,
@@ -84,6 +87,7 @@ struct SparseAttentionParams {
this->q = q;
this->k = k;
this->v = v;
+ this->is_q_bnsh = is_q_bnsh;
this->batch_size = batch_size;
this->sequence_length = sequence_length;
this->num_heads = num_heads;
@@ -91,7 +95,7 @@ struct SparseAttentionParams {
this->head_size = head_size;
this->past_sequence_length = total_sequence_length - sequence_length;
this->total_sequence_length = total_sequence_length;
- this->max_sequence_length = max_sequence_length;
+ this->max_cache_sequence_length = max_cache_sequence_length;
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast(head_size)) : scale;
this->kernel_block_size = kernel_block_size;
this->layout_csr_row_indices = layout_csr_row_indices;
@@ -101,18 +105,16 @@ struct SparseAttentionParams {
this->num_layout = num_layout;
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
- this->stride_qh = this->sequence_length * this->head_size;
+ this->stride_qh = (is_q_bnsh ? this->sequence_length : this->num_heads) * this->head_size;
this->stride_qm = this->head_size;
// When kv buffer has max sequence length, stride should match max sequence length.
- int kv_buffer_sequence_length = max_sequence_length;
-
// KV cache is in BNSH format
- this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
- this->stride_kh = kv_buffer_sequence_length * this->head_size;
+ this->stride_kb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
+ this->stride_kh = max_cache_sequence_length * this->head_size;
this->stride_kn = this->head_size;
- this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
- this->stride_vh = kv_buffer_sequence_length * this->head_size;
+ this->stride_vb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
+ this->stride_vh = max_cache_sequence_length * this->head_size;
this->stride_vn = this->head_size;
// Output is BSNH format
@@ -142,8 +144,8 @@ struct SparseAttentionParams {
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR_INIT();
DUMP_TENSOR("q", reinterpret_cast(q), batch_size, num_heads, sequence_length, head_size);
- DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_sequence_length, head_size);
- DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_sequence_length, head_size);
+ DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
+ DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
DUMP_TENSOR("csr_col_indices",
layout_csr_col_indices,
num_layout,
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h
index 328cb8b5d8..af19a90b32 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h
@@ -18,6 +18,8 @@ struct SparseAttentionParams {
const void* k;
const void* v;
+ bool is_q_bnsh;
+
int batch_size;
int num_heads;
int kv_num_heads;
@@ -26,7 +28,7 @@ struct SparseAttentionParams {
int sequence_length;
int past_sequence_length;
int total_sequence_length;
- int max_sequence_length;
+ int max_cache_sequence_length;
float scale;
@@ -70,13 +72,14 @@ struct SparseAttentionParams {
const void* q,
const void* k,
const void* v,
+ bool is_q_bnsh,
int batch_size,
int sequence_length,
int num_heads,
int kv_num_heads,
int head_size,
int total_sequence_length,
- int max_sequence_length,
+ int max_cache_sequence_length,
float scale,
int kernel_block_size,
const int* layout_csr_row_indices,
@@ -97,6 +100,7 @@ struct SparseAttentionParams {
this->q = q;
this->k = k;
this->v = v;
+ this->is_q_bnsh = is_q_bnsh;
this->batch_size = batch_size;
this->sequence_length = sequence_length;
this->num_heads = num_heads;
@@ -104,7 +108,7 @@ struct SparseAttentionParams {
this->head_size = head_size;
this->past_sequence_length = total_sequence_length - sequence_length;
this->total_sequence_length = total_sequence_length;
- this->max_sequence_length = max_sequence_length;
+ this->max_cache_sequence_length = max_cache_sequence_length;
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast(head_size)) : scale;
this->kernel_block_size = kernel_block_size;
this->layout_csr_row_indices = layout_csr_row_indices;
@@ -113,20 +117,18 @@ struct SparseAttentionParams {
this->layout_col_stride_h = layout_col_stride_h;
this->num_layout = num_layout;
- // Q is in BNSH format
+ // Q can be either BNSH or BSNH format
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
- this->stride_qh = this->sequence_length * this->head_size;
+ this->stride_qh = (is_q_bnsh ? this->sequence_length : this->num_heads) * this->head_size;
this->stride_qt = this->head_size;
// When kv buffer has max sequence length, stride should match max sequence length.
- int kv_buffer_sequence_length = max_sequence_length;
-
// KV cache is in BNSH format
- this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
- this->stride_kh = kv_buffer_sequence_length * this->head_size;
+ this->stride_kb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
+ this->stride_kh = max_cache_sequence_length * this->head_size;
this->stride_kt = this->head_size;
- this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
- this->stride_vh = kv_buffer_sequence_length * this->head_size;
+ this->stride_vb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
+ this->stride_vh = max_cache_sequence_length * this->head_size;
this->stride_vt = this->head_size;
// Output is BSNH format
@@ -167,8 +169,8 @@ struct SparseAttentionParams {
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR_INIT();
DUMP_TENSOR("q", reinterpret_cast(q), batch_size, num_heads, sequence_length, head_size);
- DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_sequence_length, head_size);
- DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_sequence_length, head_size);
+ DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
+ DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
DUMP_TENSOR("csr_col_indices",
layout_csr_col_indices,
num_layout,
@@ -187,13 +189,13 @@ struct SparseAttentionParams {
DUMP_TENSOR("q_start_sids", q_start_sids, 1, active_q_blocks);
printf(
- "layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n"
+ "layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f, is_q_bnsh=%d,\n"
"stride_qb=%d, stride_qt=%d, stride_qh=%d, stride_kb=%d, stride_kt=%d, stride_kh=%d,\n"
"stride_vb=%d, stride_vt=%d, stride_vh=%d, stride_ob=%d, stride_ot=%d, stride_oh=%d,\n"
"num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n"
"output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n"
"q_batch_starts=%p, q_batch_ends=%p, k_batch_starts=%p, k_batch_ends=%p, q_batch_ids=%p, q_start_sids=%p active_q_blocks=%d\n",
- layout_row_stride_h, layout_col_stride_h, num_layout, scale,
+ layout_row_stride_h, layout_col_stride_h, num_layout, scale, static_cast(is_q_bnsh),
stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh,
stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh,
num_heads, kv_num_heads, total_sequence_length, past_sequence_length,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index e9de04f8a9..916f0c92fd 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -287,7 +287,7 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte
auto& past_shape = getInputShape(ctx, past_key_index);
auto& past_dims = past_shape.dim();
- // past key has shape (batch_size, kv_num_heads, max_sequence_length, head_size)
+ // past key has shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
if (past_dims.size() != 4) {
fail_shape_inference("The past_key input shall be 4 dimensions");
}
@@ -1152,11 +1152,29 @@ block_mask can be used to configure sparse layout for different head.
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
-Padding shall be on the right side.
+The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain
+paddings at the right side when different layout has different number of non-zeros in block mask.
-When do_rotary is True, cos_cache and sin_cache are required.
+An example of block mask with 2 layouts where each layout is 4 x 4 blocks:
+ [[[1, 0, 0, 0],
+ [1, 1, 0, 0],
+ [0, 1, 1, 0],
+ [0, 1, 1, 1]],
+
+ [[1, 0, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 0],
+ [1, 0, 1, 1]]]
+
+The corresponding CSR format:
+ block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
+ block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]
+
+When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
+or sin cache can be different from the maximum sequence length used by kv cache.
Only supports unidirectional attention with cache of past key and value in linear buffers.
+
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
)DOC";
@@ -1190,36 +1208,38 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(3,
"past_key",
- "Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
- "T",
- OpSchema::Optional)
+ "Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)",
+ "T")
.Input(4,
"past_value",
- "Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
- "T",
- OpSchema::Optional)
+ "Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)",
+ "T")
.Input(5,
- "block_mask",
- "block mask. 1 indicates attention and 0 no attention. "
- "Its shape is (num_layout, max_blocks, max_blocks), "
- "where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.",
+ "block_row_indices",
+ "The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1)."
+ "The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.",
"M")
.Input(6,
+ "block_col_indices",
+ "The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks)."
+ "The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.",
+ "M")
+ .Input(7,
"total_sequence_length",
"Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.",
"M")
- .Input(7,
+ .Input(8,
"key_total_sequence_lengths",
"1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.",
"M")
- .Input(8,
+ .Input(9,
"cos_cache",
- "Cos cache of rotary with shape (max_sequence_length, head_size / 2).",
+ "Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
- .Input(9,
+ .Input(10,
"sin_cache",
- "Sin cache of rotary with shape (max_sequence_length, head_size / 2).",
+ "Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
.Output(0,
@@ -1228,11 +1248,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T")
.Output(1,
"present_key",
- "Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
+ "Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).",
"T")
.Output(2,
"present_value",
- "Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
+ "Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).",
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.")
diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index 9fcc23288e..2dd6dc627b 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -38,21 +38,27 @@ class AttentionConfig:
dtype=torch.float16,
share_buffer: bool = True,
is_packed_qkv: bool = False,
+ max_cache_sequence_length=None,
+ max_rotary_sequence_length=None,
):
self.operator = operator
self.batch_size = batch_size
self.sequence_length = sequence_length
self.max_sequence_length = max_sequence_length
+ self.max_cache_sequence_length = max_cache_sequence_length or max_sequence_length
+ self.max_rotary_sequence_length = max_rotary_sequence_length or max_sequence_length
self.past_sequence_length = past_sequence_length
self.num_heads = num_heads
self.kv_num_heads = kv_num_heads
self.head_size = head_size
- self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
+ self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5))
# Derived values
self.total_sequence_length = sequence_length + past_sequence_length
- self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length
- self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length)
+ self.past_buffer_length = self.max_cache_sequence_length if share_buffer else past_sequence_length
+ self.present_buffer_length = (
+ self.max_cache_sequence_length if share_buffer else (past_sequence_length + sequence_length)
+ )
self.do_rotary = do_rotary
self.rotary_interleaved = rotary_interleaved
@@ -75,8 +81,8 @@ class AttentionConfig:
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
- "cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
- "sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
+ "cos_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
+ "sin_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
}
if not self.is_packed_qkv:
@@ -92,7 +98,7 @@ class AttentionConfig:
def get_cos_sin_cache(self, dtype):
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
- angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
+ angle = torch.rand(self.max_rotary_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
return cos.to(device=self.device), sin.to(device=self.device)
@@ -151,6 +157,8 @@ class GroupQueryAttentionConfig(AttentionConfig):
local_window_size: int = -1,
attention_mask=None,
is_packed_qkv=False,
+ max_cache_sequence_length=None,
+ max_rotary_sequence_length=None,
):
super().__init__(
"GroupQueryAttention",
@@ -166,6 +174,8 @@ class GroupQueryAttentionConfig(AttentionConfig):
rotary_interleaved,
device,
is_packed_qkv=is_packed_qkv,
+ max_cache_sequence_length=max_cache_sequence_length,
+ max_rotary_sequence_length=max_rotary_sequence_length,
)
# local_window_size is for ORT only, not for Torch implementation.
self.local_window_size = local_window_size
@@ -212,6 +222,8 @@ class SparseAttentionConfig(AttentionConfig):
rotary_interleaved: bool = False,
device="cuda",
is_packed_qkv=False,
+ max_cache_sequence_length=None,
+ max_rotary_sequence_length=None,
):
super().__init__(
"SparseAttention",
@@ -227,6 +239,8 @@ class SparseAttentionConfig(AttentionConfig):
rotary_interleaved,
device,
is_packed_qkv=is_packed_qkv,
+ max_cache_sequence_length=max_cache_sequence_length,
+ max_rotary_sequence_length=max_rotary_sequence_length,
)
self.sparse_block_size = sparse_block_size
self.num_layout = num_layout
@@ -237,18 +251,23 @@ class SparseAttentionConfig(AttentionConfig):
def block_mask(self):
return get_block_mask(self.num_layout, self.max_blocks, self.local_blocks, self.vert_stride).to(self.device)
+ def block_indices(self):
+ row_indices, col_indices = dense_to_csr(self.block_mask())
+ return row_indices.to(torch.int32).to(self.device), col_indices.to(torch.int32).to(self.device)
+
def dense_mask(self):
- expand_block_mask = self.block_mask()
dense_mask = get_dense_mask(
- expand_block_mask, self.total_sequence_length, self.sequence_length, self.sparse_block_size
+ self.block_mask(), self.total_sequence_length, self.sequence_length, self.sparse_block_size
)
return dense_mask.repeat(self.batch_size, self.num_heads // self.num_layout, 1, 1).to(self.device)
def shape_dict(self):
shapes = super().shape_dict()
+ block_row_indices, block_col_indices = self.block_indices()
shapes.update(
{
- "block_mask": (self.num_layout, self.max_blocks, self.max_blocks),
+ "block_row_indices": tuple(block_row_indices.shape),
+ "block_col_indices": tuple(block_col_indices.shape),
"key_total_sequence_lengths": (self.batch_size,),
}
)
@@ -257,10 +276,11 @@ class SparseAttentionConfig(AttentionConfig):
def random_inputs(self):
feeds = super().random_inputs()
k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length
+ block_row_indices, block_col_indices = self.block_indices()
feeds.update(
{
- "block_mask": self.block_mask(),
- "total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
+ "block_row_indices": block_row_indices,
+ "block_col_indices": block_col_indices,
"key_total_sequence_lengths": k_seqlens,
}
)
@@ -281,6 +301,8 @@ class SparseAttentionConfig(AttentionConfig):
self.device,
local_window_size=self.local_blocks * self.sparse_block_size if use_local else -1,
is_packed_qkv=self.is_packed_qkv,
+ max_cache_sequence_length=self.max_cache_sequence_length,
+ max_rotary_sequence_length=self.max_rotary_sequence_length,
)
def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttentionConfig:
@@ -305,6 +327,8 @@ class SparseAttentionConfig(AttentionConfig):
self.device,
attention_mask=attention_mask,
is_packed_qkv=False, # torch reference implementation does not support packed qkv.
+ max_cache_sequence_length=self.max_cache_sequence_length,
+ max_rotary_sequence_length=self.max_rotary_sequence_length,
)
@@ -327,6 +351,19 @@ def get_block_mask(num_layout, max_blocks, local_blocks, vert_stride):
return block_mask
+def dense_to_csr(x):
+ """Turning a 3D torch tensor (x) to CSR rows/cols indexing."""
+ assert x.dim() == 3
+ pad = -1
+ x = [xi.to_sparse_csr() for xi in x]
+ row_indices = torch.vstack([xi.crow_indices() for xi in x])
+ cols = [xi.col_indices() for xi in x]
+ max_cols = max(len(xi) for xi in cols)
+ cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
+ col_indices = torch.vstack(cols)
+ return row_indices, col_indices
+
+
def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size):
dense_mask = torch.kron(block_mask, block_mask.new_ones((block_size, block_size)))[
:, :total_seq_len, :total_seq_len
@@ -350,7 +387,8 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig):
"value" + suffix if not config.is_packed_qkv else "",
"past_key" + suffix,
"past_value" + suffix,
- "block_mask",
+ "block_row_indices", # no suffix: int32 inputs need no cast in the bfloat16 graph.
+ "block_col_indices",
"total_sequence_length" if config.share_buffer else "",
"key_total_sequence_lengths",
"cos_cache" + suffix if config.do_rotary else "",
@@ -410,7 +448,12 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig):
[
helper.make_tensor_value_info("past_key", io_float_type, list(shape_dict["past_key"])),
helper.make_tensor_value_info("past_value", io_float_type, list(shape_dict["past_value"])),
- helper.make_tensor_value_info("block_mask", TensorProto.INT32, list(shape_dict["block_mask"])),
+ helper.make_tensor_value_info(
+ "block_row_indices", TensorProto.INT32, list(shape_dict["block_row_indices"])
+ ),
+ helper.make_tensor_value_info(
+ "block_col_indices", TensorProto.INT32, list(shape_dict["block_col_indices"])
+ ),
helper.make_tensor_value_info(
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
),
@@ -704,7 +747,8 @@ class OrtSparseAttention:
print("query(BSNH, SA)", query)
print("key(BSNH, SA)", key)
print("value(BSNH, SA)", value)
- print("block_mask (SA)", self.feed_dict["block_mask"])
+ print("block_row_indices", self.feed_dict["block_row_indices"])
+ print("block_col_indices", self.feed_dict["block_col_indices"])
print("total_sequence_length", self.feed_dict["total_sequence_length"])
print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"])
@@ -778,6 +822,7 @@ class TestSparseAttention(unittest.TestCase):
softmax_scale=1.8 / (128**0.5),
device=device,
is_packed_qkv=packed_qkv,
+ max_cache_sequence_length=None if seq_len >= 128 else 128, # test smaller kv cache buffer.
)
self.run_one_relevance_test(config)
@@ -806,6 +851,7 @@ class TestSparseAttention(unittest.TestCase):
rotary_interleaved=(past_seq_len % 2 == 1),
device=device,
is_packed_qkv=packed_qkv,
+ max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer.
)
if do_rotary: