mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Update SparseAttention op spec to make it more flexible (#20625)
### Description Make the operator more flexible: (1) Decouple the max sequence length of rotary cache, kv cache and block mask. They are allowed to have different values. (2) Replace block_mask dense by CSR format (block_row_indices and block_col_indices) to improve performance. (3) Mark past_key and past_value as required inputs since we need them to compute the shape of present_key and present_value. ### Motivation and Context (1) LongRoPE has short and long rotary cache, which has different length. (2) Most users do not have enough GPU memory to run maximum sequence length 128K. This change allows user to use smaller kv cache length to test without hitting out of memory.
This commit is contained in:
parent
a0c4bd4da7
commit
01dd991f97
11 changed files with 313 additions and 220 deletions
|
|
@ -5553,11 +5553,29 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
|
||||
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
|
||||
|
||||
Padding shall be on the right side.
|
||||
The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain
|
||||
paddings at the right side when different layout has different number of non-zeros in block mask.
|
||||
|
||||
When do_rotary is True, cos_cache and sin_cache are required.
|
||||
An example of block mask with 2 layouts where each layout is 4 x 4 blocks:
|
||||
[[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 1, 1, 1]],
|
||||
|
||||
[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 0],
|
||||
[1, 0, 1, 1]]]
|
||||
|
||||
The corresponding CSR format:
|
||||
block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
|
||||
block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]
|
||||
|
||||
When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
|
||||
or sin cache can be different from the maximum sequence length used by kv cache.
|
||||
|
||||
Only supports unidirectional attention with cache of past key and value in linear buffers.
|
||||
|
||||
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
|
||||
|
||||
#### Version
|
||||
|
|
@ -5581,7 +5599,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>Number of tokens per sparse block. Choices: 16, 32, 64, 128</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (8 - 10)
|
||||
#### Inputs (9 - 11)
|
||||
|
||||
<dl>
|
||||
<dt><tt>query</tt> : T</dt>
|
||||
|
|
@ -5590,20 +5608,22 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>Key with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
|
||||
<dt><tt>value</tt> (optional) : T</dt>
|
||||
<dd>Value with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
|
||||
<dt><tt>past_key</tt> (optional) : T</dt>
|
||||
<dd>Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>past_value</tt> (optional) : T</dt>
|
||||
<dd>Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>block_mask</tt> : M</dt>
|
||||
<dd>block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
|
||||
<dt><tt>past_key</tt> : T</dt>
|
||||
<dd>Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
|
||||
<dt><tt>past_value</tt> : T</dt>
|
||||
<dd>Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
|
||||
<dt><tt>block_row_indices</tt> : M</dt>
|
||||
<dd>The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1).The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
|
||||
<dt><tt>block_col_indices</tt> : M</dt>
|
||||
<dd>The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks).The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.</dd>
|
||||
<dt><tt>total_sequence_length</tt> : M</dt>
|
||||
<dd>Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.</dd>
|
||||
<dt><tt>key_total_sequence_lengths</tt> : M</dt>
|
||||
<dd>1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.</dd>
|
||||
<dt><tt>cos_cache</tt> (optional) : T</dt>
|
||||
<dd>Cos cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
|
||||
<dd>Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
|
||||
<dt><tt>sin_cache</tt> (optional) : T</dt>
|
||||
<dd>Sin cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
|
||||
<dd>Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
|
@ -5612,9 +5632,9 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dt><tt>output</tt> : T</dt>
|
||||
<dd>3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)</dd>
|
||||
<dt><tt>present_key</tt> : T</dt>
|
||||
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
|
||||
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
|
||||
<dt><tt>present_value</tt> : T</dt>
|
||||
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
|
||||
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
|
|
|||
|
|
@ -906,7 +906,7 @@ Do not modify directly.*
|
|||
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_mask:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|
||||
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|
||||
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -126,11 +126,15 @@ struct SparseAttentionParameters {
|
|||
bool rotary_interleaved; // whether to use interleaved rotary embedding
|
||||
int rotary_dim; // rotary embedding dimension
|
||||
int sparse_block_size; // block size for sparse attention
|
||||
int num_sparse_layout; // number of sparse layout, or the first dimension of block_mask
|
||||
int num_sparse_layout; // number of sparse layout
|
||||
int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
|
||||
int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
|
||||
float scale; // scaling factor applied prior to softmax
|
||||
bool is_packed_qkv; // whether qkv is packed
|
||||
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
|
||||
int max_sequence_length; // max sequence length allowed
|
||||
int max_sequence_length; // max sequence length for sparse layout
|
||||
int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache
|
||||
int max_cache_sequence_length; // max sequence length for kv cache buffer
|
||||
bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
|
||||
#include "contrib_ops/cuda/sparse/block_mask.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
|
|
@ -26,7 +25,7 @@ namespace cuda {
|
|||
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \
|
||||
.MayInplace(3, 1) \
|
||||
.MayInplace(4, 2) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 6), \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 7), \
|
||||
SparseAttention<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
|
@ -77,15 +76,16 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* value = context->Input<Tensor>(2);
|
||||
const Tensor* past_key = context->Input<Tensor>(3);
|
||||
const Tensor* past_value = context->Input<Tensor>(4);
|
||||
const Tensor* block_mask = context->Input<Tensor>(5);
|
||||
const Tensor* total_seq_len = context->Input<Tensor>(6);
|
||||
const Tensor* seqlens_k_total = context->Input<Tensor>(7);
|
||||
const Tensor* cos_cache = context->Input<Tensor>(8);
|
||||
const Tensor* sin_cache = context->Input<Tensor>(9);
|
||||
const Tensor* block_row_indices = context->Input<Tensor>(5);
|
||||
const Tensor* block_col_indices = context->Input<Tensor>(6);
|
||||
const Tensor* total_seq_len = context->Input<Tensor>(7);
|
||||
const Tensor* seqlens_k_total = context->Input<Tensor>(8);
|
||||
const Tensor* cos_cache = context->Input<Tensor>(9);
|
||||
const Tensor* sin_cache = context->Input<Tensor>(10);
|
||||
|
||||
SparseAttentionParameters parameters;
|
||||
|
||||
// Parameters from node attribute
|
||||
// Parameters from node attribute shall be set before calling CheckInputs
|
||||
parameters.sparse_block_size = sparse_block_size_;
|
||||
parameters.num_heads = num_heads_;
|
||||
parameters.kv_num_heads = kv_num_heads_;
|
||||
|
|
@ -101,7 +101,8 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
past_value,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
block_mask,
|
||||
block_row_indices,
|
||||
block_col_indices,
|
||||
seqlens_k_total,
|
||||
total_seq_len));
|
||||
// Some limitations of CUDA kernels
|
||||
|
|
@ -177,7 +178,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
std::vector<int64_t> present_dims = {
|
||||
parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
|
||||
parameters.batch_size, parameters.kv_num_heads, parameters.max_cache_sequence_length, parameters.head_size};
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present_key = context->Output(1, present_shape);
|
||||
Tensor* present_value = context->Output(2, present_shape);
|
||||
|
|
@ -188,13 +189,12 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
|
||||
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
|
||||
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
|
||||
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
|
||||
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
|
||||
data.block_mask = block_mask->Data<int32_t>();
|
||||
data.past_key = reinterpret_cast<const CudaT*>(past_key->Data<T>());
|
||||
data.past_value = reinterpret_cast<const CudaT*>(past_value->Data<T>());
|
||||
data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data<int32_t>();
|
||||
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
|
||||
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
|
||||
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
|
||||
data.present_key = reinterpret_cast<CudaT*>(present_key->MutableData<T>());
|
||||
data.present_value = reinterpret_cast<CudaT*>(present_value->MutableData<T>());
|
||||
|
||||
// Check past and present share buffer.
|
||||
parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
|
||||
|
|
@ -214,29 +214,9 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Currently, we use same block size in kernel.
|
||||
// TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask).
|
||||
data.kernel_layout.block_size = parameters.sparse_block_size;
|
||||
data.kernel_layout.mask = data.block_mask;
|
||||
data.kernel_layout.num_layout = parameters.num_sparse_layout;
|
||||
data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size;
|
||||
data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size;
|
||||
|
||||
// Allocate buffer for CSR col and row indices.
|
||||
onnxruntime::Stream* stream = context->GetComputeStream();
|
||||
int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows;
|
||||
auto csr_col_indices_buffer = GetScratchBuffer<int>(static_cast<size_t>(dense_blocks), stream);
|
||||
auto csr_row_indices_buffer = GetScratchBuffer<int>(
|
||||
static_cast<size_t>(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream);
|
||||
|
||||
data.kernel_layout.csr_col_indices = reinterpret_cast<const int*>(csr_col_indices_buffer.get());
|
||||
data.kernel_layout.csr_row_indices = reinterpret_cast<const int*>(csr_row_indices_buffer.get());
|
||||
|
||||
ConvertMaskToCSR(cuda_stream,
|
||||
data.kernel_layout.mask,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows,
|
||||
data.kernel_layout.num_cols,
|
||||
csr_row_indices_buffer.get(),
|
||||
csr_col_indices_buffer.get(),
|
||||
device_prop.maxThreadsPerBlock);
|
||||
data.kernel_layout.csr_col_indices = block_col_indices->Data<int32_t>();
|
||||
data.kernel_layout.csr_row_indices = block_row_indices->Data<int32_t>();
|
||||
|
||||
size_t rotary_buffer_bytes = 0;
|
||||
if (do_rotary_) {
|
||||
|
|
@ -244,7 +224,8 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.sequence_length * parameters.head_size;
|
||||
rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
|
||||
}
|
||||
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
|
||||
onnxruntime::Stream* stream = context->GetComputeStream();
|
||||
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, stream);
|
||||
data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
|
||||
|
||||
size_t transposed_q_bytes = 0;
|
||||
|
|
@ -252,7 +233,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
transposed_q_bytes = parameters.batch_size * parameters.sequence_length *
|
||||
parameters.num_heads * parameters.head_size * sizeof(T);
|
||||
}
|
||||
auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, context->GetComputeStream());
|
||||
auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, stream);
|
||||
if (transposed_q_buffer) {
|
||||
data.transposed_q_buffer = reinterpret_cast<CudaT*>(transposed_q_buffer.get());
|
||||
}
|
||||
|
|
@ -263,7 +244,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
(parameters.num_heads + 2 * parameters.kv_num_heads) *
|
||||
parameters.head_size * sizeof(T));
|
||||
}
|
||||
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
|
||||
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, stream);
|
||||
if (unpacked_qkv_buffer) {
|
||||
data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
|
||||
}
|
||||
|
|
@ -327,7 +308,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, context->GetComputeStream());
|
||||
v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, stream);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned,
|
||||
sizeof(int32_t) * v2_kernel_buffer_size,
|
||||
cudaMemcpyHostToDevice, cuda_stream));
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ Status CheckInputs(void* params,
|
|||
const Tensor* past_value,
|
||||
const Tensor* cos_cache,
|
||||
const Tensor* sin_cache,
|
||||
const Tensor* block_mask,
|
||||
const Tensor* block_row_indices,
|
||||
const Tensor* block_col_indices,
|
||||
const Tensor* seqlens_k_total,
|
||||
const Tensor* total_seq_len) {
|
||||
// No packing for q/k/v:
|
||||
|
|
@ -31,14 +32,14 @@ Status CheckInputs(void* params,
|
|||
// key nullptr
|
||||
// value nullptr
|
||||
// Shape for other inputs:
|
||||
// past_key (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
|
||||
// past_value (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
|
||||
// block_mask (num_heads, max_blocks, max_blocks) or (1, max_blocks, max_blocks)
|
||||
// where max_blocks = max_sequence_length / sparse_block_size
|
||||
// past_key (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
|
||||
// past_value (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
|
||||
// block_row_indices (num_layout, max_blocks + 1), where max_blocks = max_sequence_length / sparse_block_size
|
||||
// block_col_indices (num_layout, max_nnz)
|
||||
// seqlens_k_total (batch_size) when do_rotary is True, optional otherwise
|
||||
// total_seq_len (1)
|
||||
// cos_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
|
||||
// sin_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
|
||||
// cos_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true.
|
||||
// sin_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true.
|
||||
|
||||
assert(params != nullptr);
|
||||
SparseAttentionParameters* parameters = reinterpret_cast<SparseAttentionParameters*>(params);
|
||||
|
|
@ -121,57 +122,78 @@ Status CheckInputs(void* params,
|
|||
kv_hidden_size = head_size * kv_num_heads;
|
||||
}
|
||||
|
||||
const auto& block_mask_dim = block_mask->Shape().GetDims();
|
||||
if (!(block_mask_dim.size() == 3 && block_mask_dim[1] == block_mask_dim[2] &&
|
||||
(static_cast<int64_t>(num_heads) % block_mask_dim[0] == 0L))) {
|
||||
if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"total_sequence_length tensor must be of one element.");
|
||||
}
|
||||
int total_sequence_length = *((*total_seq_len).template Data<int32_t>());
|
||||
|
||||
// Check block_row_indices
|
||||
const auto& block_row_indices_dim = block_row_indices->Shape().GetDims();
|
||||
if (!(block_row_indices_dim.size() == 2 &&
|
||||
block_row_indices_dim[1] > 1 &&
|
||||
(static_cast<int64_t>(num_heads) % block_row_indices_dim[0] == 0L))) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"block_mask must have shape (num_layout, max_blocks, max_blocks) where num_heads is divisible by num_layout.");
|
||||
"block_row_indices must have shape (num_layout, max_blocks + 1) where num_heads is divisible by num_layout.");
|
||||
}
|
||||
int max_blocks = static_cast<int>(block_row_indices_dim[1]) - 1;
|
||||
|
||||
// Check block_col_indices
|
||||
const auto& block_col_indices_dim = block_col_indices->Shape().GetDims();
|
||||
if (!(block_col_indices_dim.size() == 2 &&
|
||||
block_col_indices_dim[0] == block_row_indices_dim[0] &&
|
||||
block_col_indices_dim[1] <= max_blocks * max_blocks)) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"block_col_indices must have shape (num_layout, max_nnz), "
|
||||
"where max_nnz <= max_blocks * max_blocks.");
|
||||
}
|
||||
|
||||
int max_blocks = static_cast<int>(block_mask_dim[1]);
|
||||
int max_sequence_length = max_blocks * parameters->sparse_block_size;
|
||||
|
||||
// Check past-present KV
|
||||
if (past_key != nullptr && past_value != nullptr) {
|
||||
if (past_key->Shape() != past_value->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'past_value' shall have same shape");
|
||||
}
|
||||
|
||||
const auto& past_key_dims = past_key->Shape().GetDims();
|
||||
if (past_key_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' is expected to have 4 dimensions, got ",
|
||||
past_key_dims.size());
|
||||
}
|
||||
|
||||
if (past_key_dims[0] != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ",
|
||||
past_key_dims[0]);
|
||||
}
|
||||
|
||||
if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads");
|
||||
}
|
||||
|
||||
int max_cache_sequence_length = static_cast<int>(past_key_dims[is_past_bsnh ? 1 : 2]);
|
||||
if (max_cache_sequence_length != max_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'block_mask' should have the same sequence length:",
|
||||
"max_sequence_length deduced from past_key is ", max_cache_sequence_length,
|
||||
"; max_sequence_length deduced from block_mask is ", max_sequence_length);
|
||||
}
|
||||
|
||||
if (past_key_dims[3] != head_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' dimension 3 should be same as head_size, got ",
|
||||
past_key_dims[3]);
|
||||
}
|
||||
} else if (past_key != nullptr || past_value != nullptr) {
|
||||
if (max_sequence_length < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'past_value' shall be both present or both absent.");
|
||||
"max_sequence_length should be no less than total_sequence_length:",
|
||||
total_sequence_length,
|
||||
", max_sequence_length deduced from block_row_indices:", max_sequence_length);
|
||||
}
|
||||
|
||||
// Check kv cache
|
||||
ORT_ENFORCE(past_key != nullptr && past_value != nullptr);
|
||||
if (past_key->Shape() != past_value->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'past_value' shall have same shape");
|
||||
}
|
||||
|
||||
const auto& past_key_dims = past_key->Shape().GetDims();
|
||||
if (past_key_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' is expected to have 4 dimensions, got ",
|
||||
past_key_dims.size());
|
||||
}
|
||||
|
||||
if (past_key_dims[0] != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ",
|
||||
past_key_dims[0]);
|
||||
}
|
||||
|
||||
if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads");
|
||||
}
|
||||
|
||||
int max_cache_sequence_length = static_cast<int>(past_key_dims[is_past_bsnh ? 1 : 2]);
|
||||
if (max_cache_sequence_length < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"max_cache_sequence_length should be no less than total_sequence_length:",
|
||||
total_sequence_length,
|
||||
", max_cache_sequence_length:", max_cache_sequence_length);
|
||||
}
|
||||
|
||||
if (past_key_dims[3] != head_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' dimension 3 should be same as head_size, got ",
|
||||
past_key_dims[3]);
|
||||
}
|
||||
|
||||
// Check the shape of total_key_sequence_lengths. We do not check the values here.
|
||||
|
|
@ -181,13 +203,8 @@ Status CheckInputs(void* params,
|
|||
"key_total_sequence_lengths must have shape (batch_size).");
|
||||
}
|
||||
|
||||
if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"total_sequence_length tensor must be of one element.");
|
||||
}
|
||||
int total_sequence_length = *((*total_seq_len).template Data<int32_t>());
|
||||
|
||||
int rotary_dim = 0;
|
||||
int max_rotary_sequence_length = 0;
|
||||
if (do_rotary) {
|
||||
if (cos_cache == nullptr || sin_cache == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
@ -202,14 +219,19 @@ Status CheckInputs(void* params,
|
|||
"head_size shall be a multiple of 16. Got head_size = ",
|
||||
head_size);
|
||||
}
|
||||
if (cos_dims[0] < total_sequence_length) {
|
||||
if (cos_dims[0] != sin_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 0 should be not be less than total_sequence_length.");
|
||||
"cos_cache and sin_cache dimension 0 should be same size.");
|
||||
}
|
||||
if (sin_dims[0] < total_sequence_length) {
|
||||
|
||||
max_rotary_sequence_length = static_cast<int>(cos_dims[0]);
|
||||
if (max_rotary_sequence_length < total_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 0 should be not be less than total_sequence_length.");
|
||||
"max_rotary_sequence_length should be no less than total_sequence_length:",
|
||||
total_sequence_length,
|
||||
", max_rotary_sequence_length:", max_rotary_sequence_length);
|
||||
}
|
||||
|
||||
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
|
||||
|
|
@ -229,12 +251,16 @@ Status CheckInputs(void* params,
|
|||
parameters->sequence_length = sequence_length;
|
||||
parameters->total_sequence_length = total_sequence_length;
|
||||
parameters->max_sequence_length = max_sequence_length;
|
||||
parameters->max_cache_sequence_length = max_cache_sequence_length;
|
||||
parameters->max_rotary_sequence_length = max_rotary_sequence_length;
|
||||
parameters->hidden_size = q_hidden_size;
|
||||
parameters->head_size = head_size;
|
||||
parameters->kv_hidden_size = kv_hidden_size;
|
||||
parameters->rotary_dim = rotary_dim;
|
||||
parameters->is_packed_qkv = is_packed_qkv;
|
||||
parameters->num_sparse_layout = static_cast<int>(block_mask_dim[0]);
|
||||
parameters->num_sparse_layout = static_cast<int>(block_row_indices_dim[0]);
|
||||
parameters->stride_row_indices = static_cast<int>(block_row_indices_dim[1]);
|
||||
parameters->stride_col_indices = static_cast<int>(block_col_indices_dim[1]);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
|
||||
#include "contrib_ops/cuda/sparse/block_mask.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
|
||||
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
|
||||
|
|
@ -88,7 +87,7 @@ Status LaunchConcatKVInPlace(contrib::SparseAttentionParameters& parameters,
|
|||
return LaunchConcatKVInPlace(parameters.batch_size,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.max_sequence_length,
|
||||
parameters.max_cache_sequence_length,
|
||||
nullptr,
|
||||
data.seqlens_k_total,
|
||||
parameters.sequence_length,
|
||||
|
|
@ -112,7 +111,6 @@ Status QkvToContext(
|
|||
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
// const int present_sequence_length = parameters.max_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
|
|
@ -182,14 +180,14 @@ Status QkvToContext(
|
|||
position_ids_buff, data.cos_cache, data.sin_cache,
|
||||
parameters.batch_size, parameters.sequence_length,
|
||||
parameters.num_heads, parameters.head_size,
|
||||
parameters.rotary_dim, parameters.max_sequence_length,
|
||||
parameters.rotary_dim, parameters.max_rotary_sequence_length,
|
||||
/*position_ids_format*/ 1, parameters.rotary_interleaved,
|
||||
max_threads_per_block, q_layout));
|
||||
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, k_buffer, reinterpret_cast<const T*>(key),
|
||||
position_ids_buff, data.cos_cache, data.sin_cache,
|
||||
parameters.batch_size, parameters.sequence_length,
|
||||
parameters.kv_num_heads, parameters.head_size,
|
||||
parameters.rotary_dim, parameters.max_sequence_length,
|
||||
parameters.rotary_dim, parameters.max_rotary_sequence_length,
|
||||
/*position_ids_format*/ 1, parameters.rotary_interleaved,
|
||||
max_threads_per_block, kv_layout));
|
||||
query = reinterpret_cast<const void*>(q_buffer);
|
||||
|
|
@ -215,29 +213,24 @@ Status QkvToContext(
|
|||
|
||||
// TODO: only dump to total sequence length instead of max sequence length.
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
|
||||
DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
|
||||
|
||||
DUMP_TENSOR("block_mask",
|
||||
data.kernel_layout.mask,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows,
|
||||
data.kernel_layout.num_cols);
|
||||
DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads,
|
||||
parameters.max_cache_sequence_length, head_size);
|
||||
DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads,
|
||||
parameters.max_cache_sequence_length, head_size);
|
||||
|
||||
DUMP_TENSOR("csr_col_indices",
|
||||
data.kernel_layout.csr_col_indices,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows,
|
||||
data.kernel_layout.num_cols);
|
||||
parameters.stride_col_indices);
|
||||
|
||||
DUMP_TENSOR("csr_row_indices",
|
||||
data.kernel_layout.csr_row_indices,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows + 1);
|
||||
parameters.stride_row_indices);
|
||||
|
||||
printf(
|
||||
"batch_size=%d, sequence_length=%d, num_heads=%d, kv_num_heads=%d head_size=%d, "
|
||||
"total_sequence_length=%d, max_sequence_length=%d scale=%f block_size=%d "
|
||||
"total_sequence_length=%d, max_sequence_length=%d max_cache_sequence_length=%d scale=%f block_size=%d "
|
||||
"row_stride=%d col_stride=%d num_layout=%d\n",
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
|
|
@ -246,10 +239,11 @@ Status QkvToContext(
|
|||
parameters.head_size,
|
||||
parameters.total_sequence_length,
|
||||
parameters.max_sequence_length,
|
||||
parameters.max_cache_sequence_length,
|
||||
parameters.scale,
|
||||
data.kernel_layout.block_size,
|
||||
data.kernel_layout.num_rows + 1,
|
||||
data.kernel_layout.num_rows * data.kernel_layout.num_cols,
|
||||
parameters.stride_row_indices,
|
||||
parameters.stride_col_indices,
|
||||
data.kernel_layout.num_layout);
|
||||
#endif
|
||||
|
||||
|
|
@ -262,19 +256,20 @@ Status QkvToContext(
|
|||
reinterpret_cast<const void*>(query),
|
||||
reinterpret_cast<const void*>(data.present_key),
|
||||
reinterpret_cast<const void*>(data.present_value),
|
||||
q_layout == LAYOUT_BNSH,
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.total_sequence_length,
|
||||
parameters.max_sequence_length,
|
||||
parameters.max_cache_sequence_length,
|
||||
parameters.scale,
|
||||
data.kernel_layout.block_size, // kernel_block_size
|
||||
data.kernel_layout.csr_row_indices, // skip past_seq_len in row indices
|
||||
data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
|
||||
data.kernel_layout.num_rows + 1, // stride per head in row indices
|
||||
data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
|
||||
data.kernel_layout.block_size, // kernel_block_size
|
||||
data.kernel_layout.csr_row_indices, // shape (num_layout, stride_row_indices)
|
||||
data.kernel_layout.csr_col_indices, // shape (num_layout, stride_col_indices)
|
||||
parameters.stride_row_indices,
|
||||
parameters.stride_col_indices,
|
||||
data.kernel_layout.num_layout,
|
||||
data.active_q_blocks,
|
||||
data.q_batch_starts,
|
||||
|
|
@ -297,19 +292,20 @@ Status QkvToContext(
|
|||
reinterpret_cast<const void*>(query),
|
||||
reinterpret_cast<const void*>(data.present_key),
|
||||
reinterpret_cast<const void*>(data.present_value),
|
||||
q_layout == LAYOUT_BNSH,
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.total_sequence_length,
|
||||
parameters.max_sequence_length,
|
||||
parameters.max_cache_sequence_length,
|
||||
parameters.scale,
|
||||
data.kernel_layout.block_size, // kernel_block_size
|
||||
data.kernel_layout.csr_row_indices, // (num_layout, num_rows + 1)
|
||||
data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
|
||||
data.kernel_layout.num_rows + 1, // stride per head in row indices
|
||||
data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
|
||||
data.kernel_layout.block_size, // kernel_block_size
|
||||
data.kernel_layout.csr_row_indices, // (num_layout, stride_row_indices)
|
||||
data.kernel_layout.csr_col_indices, // (num_layout, stride_row_indices)
|
||||
parameters.stride_row_indices,
|
||||
parameters.stride_col_indices,
|
||||
data.kernel_layout.num_layout);
|
||||
|
||||
if constexpr (std::is_same<T, BFloat16>::value) {
|
||||
|
|
@ -319,7 +315,7 @@ Status QkvToContext(
|
|||
}
|
||||
}
|
||||
|
||||
DUMP_TENSOR("output", reinterpret_cast<const T*>(data.output), batch_size, num_heads, sequence_length, head_size);
|
||||
DUMP_TENSOR("output", reinterpret_cast<const T*>(data.output), batch_size, sequence_length, num_heads, head_size);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,14 +16,10 @@ namespace contrib {
|
|||
namespace cuda {
|
||||
|
||||
struct BlockLayout {
|
||||
const int32_t* mask; // shape (num_layout, num_rows, num_cols), where num_rows = num_cols = max_seq_len / block_size.
|
||||
int num_layout;
|
||||
int block_size; // kernel block size, which is <= sparse_block_size
|
||||
|
||||
const int* csr_col_indices;
|
||||
const int* csr_row_indices;
|
||||
int num_rows;
|
||||
int num_cols;
|
||||
int block_size; // kernel block size, which is <= sparse_block_size
|
||||
const int* csr_row_indices; // shape [num_layout, stride_row_indices]
|
||||
const int* csr_col_indices; // shape [num_layout, stride_col_indices]
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ struct SparseAttentionParams {
|
|||
const void* k;
|
||||
const void* v;
|
||||
|
||||
bool is_q_bnsh;
|
||||
|
||||
int batch_size;
|
||||
int num_heads;
|
||||
int kv_num_heads;
|
||||
|
|
@ -30,7 +32,7 @@ struct SparseAttentionParams {
|
|||
int sequence_length;
|
||||
int past_sequence_length;
|
||||
int total_sequence_length;
|
||||
int max_sequence_length;
|
||||
int max_cache_sequence_length;
|
||||
|
||||
float scale;
|
||||
|
||||
|
|
@ -64,13 +66,14 @@ struct SparseAttentionParams {
|
|||
const void* q,
|
||||
const void* k,
|
||||
const void* v,
|
||||
bool is_q_bnsh,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int total_sequence_length,
|
||||
int max_sequence_length,
|
||||
int max_cache_sequence_length,
|
||||
float scale,
|
||||
int kernel_block_size,
|
||||
const int* layout_csr_row_indices,
|
||||
|
|
@ -84,6 +87,7 @@ struct SparseAttentionParams {
|
|||
this->q = q;
|
||||
this->k = k;
|
||||
this->v = v;
|
||||
this->is_q_bnsh = is_q_bnsh;
|
||||
this->batch_size = batch_size;
|
||||
this->sequence_length = sequence_length;
|
||||
this->num_heads = num_heads;
|
||||
|
|
@ -91,7 +95,7 @@ struct SparseAttentionParams {
|
|||
this->head_size = head_size;
|
||||
this->past_sequence_length = total_sequence_length - sequence_length;
|
||||
this->total_sequence_length = total_sequence_length;
|
||||
this->max_sequence_length = max_sequence_length;
|
||||
this->max_cache_sequence_length = max_cache_sequence_length;
|
||||
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast<float>(head_size)) : scale;
|
||||
this->kernel_block_size = kernel_block_size;
|
||||
this->layout_csr_row_indices = layout_csr_row_indices;
|
||||
|
|
@ -101,18 +105,16 @@ struct SparseAttentionParams {
|
|||
this->num_layout = num_layout;
|
||||
|
||||
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
|
||||
this->stride_qh = this->sequence_length * this->head_size;
|
||||
this->stride_qh = (is_q_bnsh ? this->sequence_length : this->num_heads) * this->head_size;
|
||||
this->stride_qm = this->head_size;
|
||||
|
||||
// When kv buffer has max sequence length, stride should match max sequence length.
|
||||
int kv_buffer_sequence_length = max_sequence_length;
|
||||
|
||||
// KV cache is in BNSH format
|
||||
this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
|
||||
this->stride_kh = max_cache_sequence_length * this->head_size;
|
||||
this->stride_kn = this->head_size;
|
||||
this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
|
||||
this->stride_vh = max_cache_sequence_length * this->head_size;
|
||||
this->stride_vn = this->head_size;
|
||||
|
||||
// Output is BSNH format
|
||||
|
|
@ -142,8 +144,8 @@ struct SparseAttentionParams {
|
|||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("q", reinterpret_cast<const half*>(q), batch_size, num_heads, sequence_length, head_size);
|
||||
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
|
||||
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
|
||||
DUMP_TENSOR("csr_col_indices",
|
||||
layout_csr_col_indices,
|
||||
num_layout,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ struct SparseAttentionParams {
|
|||
const void* k;
|
||||
const void* v;
|
||||
|
||||
bool is_q_bnsh;
|
||||
|
||||
int batch_size;
|
||||
int num_heads;
|
||||
int kv_num_heads;
|
||||
|
|
@ -26,7 +28,7 @@ struct SparseAttentionParams {
|
|||
int sequence_length;
|
||||
int past_sequence_length;
|
||||
int total_sequence_length;
|
||||
int max_sequence_length;
|
||||
int max_cache_sequence_length;
|
||||
|
||||
float scale;
|
||||
|
||||
|
|
@ -70,13 +72,14 @@ struct SparseAttentionParams {
|
|||
const void* q,
|
||||
const void* k,
|
||||
const void* v,
|
||||
bool is_q_bnsh,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int total_sequence_length,
|
||||
int max_sequence_length,
|
||||
int max_cache_sequence_length,
|
||||
float scale,
|
||||
int kernel_block_size,
|
||||
const int* layout_csr_row_indices,
|
||||
|
|
@ -97,6 +100,7 @@ struct SparseAttentionParams {
|
|||
this->q = q;
|
||||
this->k = k;
|
||||
this->v = v;
|
||||
this->is_q_bnsh = is_q_bnsh;
|
||||
this->batch_size = batch_size;
|
||||
this->sequence_length = sequence_length;
|
||||
this->num_heads = num_heads;
|
||||
|
|
@ -104,7 +108,7 @@ struct SparseAttentionParams {
|
|||
this->head_size = head_size;
|
||||
this->past_sequence_length = total_sequence_length - sequence_length;
|
||||
this->total_sequence_length = total_sequence_length;
|
||||
this->max_sequence_length = max_sequence_length;
|
||||
this->max_cache_sequence_length = max_cache_sequence_length;
|
||||
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast<float>(head_size)) : scale;
|
||||
this->kernel_block_size = kernel_block_size;
|
||||
this->layout_csr_row_indices = layout_csr_row_indices;
|
||||
|
|
@ -113,20 +117,18 @@ struct SparseAttentionParams {
|
|||
this->layout_col_stride_h = layout_col_stride_h;
|
||||
this->num_layout = num_layout;
|
||||
|
||||
// Q is in BNSH format
|
||||
// Q can be either BNSH or BSNH format
|
||||
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
|
||||
this->stride_qh = this->sequence_length * this->head_size;
|
||||
this->stride_qh = (is_q_bnsh ? this->sequence_length : this->num_heads) * this->head_size;
|
||||
this->stride_qt = this->head_size;
|
||||
|
||||
// When kv buffer has max sequence length, stride should match max sequence length.
|
||||
int kv_buffer_sequence_length = max_sequence_length;
|
||||
|
||||
// KV cache is in BNSH format
|
||||
this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
|
||||
this->stride_kh = max_cache_sequence_length * this->head_size;
|
||||
this->stride_kt = this->head_size;
|
||||
this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vb = this->kv_num_heads * max_cache_sequence_length * this->head_size;
|
||||
this->stride_vh = max_cache_sequence_length * this->head_size;
|
||||
this->stride_vt = this->head_size;
|
||||
|
||||
// Output is BSNH format
|
||||
|
|
@ -167,8 +169,8 @@ struct SparseAttentionParams {
|
|||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("q", reinterpret_cast<const half*>(q), batch_size, num_heads, sequence_length, head_size);
|
||||
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
|
||||
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_cache_sequence_length, head_size);
|
||||
DUMP_TENSOR("csr_col_indices",
|
||||
layout_csr_col_indices,
|
||||
num_layout,
|
||||
|
|
@ -187,13 +189,13 @@ struct SparseAttentionParams {
|
|||
DUMP_TENSOR("q_start_sids", q_start_sids, 1, active_q_blocks);
|
||||
|
||||
printf(
|
||||
"layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n"
|
||||
"layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f, is_q_bnsh=%d,\n"
|
||||
"stride_qb=%d, stride_qt=%d, stride_qh=%d, stride_kb=%d, stride_kt=%d, stride_kh=%d,\n"
|
||||
"stride_vb=%d, stride_vt=%d, stride_vh=%d, stride_ob=%d, stride_ot=%d, stride_oh=%d,\n"
|
||||
"num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n"
|
||||
"output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n"
|
||||
"q_batch_starts=%p, q_batch_ends=%p, k_batch_starts=%p, k_batch_ends=%p, q_batch_ids=%p, q_start_sids=%p active_q_blocks=%d\n",
|
||||
layout_row_stride_h, layout_col_stride_h, num_layout, scale,
|
||||
layout_row_stride_h, layout_col_stride_h, num_layout, scale, static_cast<int>(is_q_bnsh),
|
||||
stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh,
|
||||
stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh,
|
||||
num_heads, kv_num_heads, total_sequence_length, past_sequence_length,
|
||||
|
|
|
|||
|
|
@ -287,7 +287,7 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte
|
|||
auto& past_shape = getInputShape(ctx, past_key_index);
|
||||
auto& past_dims = past_shape.dim();
|
||||
|
||||
// past key has shape (batch_size, kv_num_heads, max_sequence_length, head_size)
|
||||
// past key has shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
|
||||
if (past_dims.size() != 4) {
|
||||
fail_shape_inference("The past_key input shall be 4 dimensions");
|
||||
}
|
||||
|
|
@ -1152,11 +1152,29 @@ block_mask can be used to configure sparse layout for different head.
|
|||
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
|
||||
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
|
||||
|
||||
Padding shall be on the right side.
|
||||
The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain
|
||||
paddings at the right side when different layout has different number of non-zeros in block mask.
|
||||
|
||||
When do_rotary is True, cos_cache and sin_cache are required.
|
||||
An example of block mask with 2 layouts where each layout is 4 x 4 blocks:
|
||||
[[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 1, 1, 1]],
|
||||
|
||||
[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 0],
|
||||
[1, 0, 1, 1]]]
|
||||
|
||||
The corresponding CSR format:
|
||||
block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
|
||||
block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]
|
||||
|
||||
When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
|
||||
or sin cache can be different from the maximum sequence length used by kv cache.
|
||||
|
||||
Only supports unidirectional attention with cache of past key and value in linear buffers.
|
||||
|
||||
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
|
||||
)DOC";
|
||||
|
||||
|
|
@ -1190,36 +1208,38 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
OpSchema::Optional)
|
||||
.Input(3,
|
||||
"past_key",
|
||||
"Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
"Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)",
|
||||
"T")
|
||||
.Input(4,
|
||||
"past_value",
|
||||
"Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
"Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)",
|
||||
"T")
|
||||
.Input(5,
|
||||
"block_mask",
|
||||
"block mask. 1 indicates attention and 0 no attention. "
|
||||
"Its shape is (num_layout, max_blocks, max_blocks), "
|
||||
"where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.",
|
||||
"block_row_indices",
|
||||
"The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1)."
|
||||
"The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.",
|
||||
"M")
|
||||
.Input(6,
|
||||
"block_col_indices",
|
||||
"The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks)."
|
||||
"The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.",
|
||||
"M")
|
||||
.Input(7,
|
||||
"total_sequence_length",
|
||||
"Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.",
|
||||
"M")
|
||||
.Input(7,
|
||||
.Input(8,
|
||||
"key_total_sequence_lengths",
|
||||
"1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.",
|
||||
"M")
|
||||
.Input(8,
|
||||
.Input(9,
|
||||
"cos_cache",
|
||||
"Cos cache of rotary with shape (max_sequence_length, head_size / 2).",
|
||||
"Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(9,
|
||||
.Input(10,
|
||||
"sin_cache",
|
||||
"Sin cache of rotary with shape (max_sequence_length, head_size / 2).",
|
||||
"Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Output(0,
|
||||
|
|
@ -1228,11 +1248,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"T")
|
||||
.Output(1,
|
||||
"present_key",
|
||||
"Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
|
||||
"Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).",
|
||||
"T")
|
||||
.Output(2,
|
||||
"present_value",
|
||||
"Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
|
||||
"Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).",
|
||||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.")
|
||||
|
|
|
|||
|
|
@ -38,21 +38,27 @@ class AttentionConfig:
|
|||
dtype=torch.float16,
|
||||
share_buffer: bool = True,
|
||||
is_packed_qkv: bool = False,
|
||||
max_cache_sequence_length=None,
|
||||
max_rotary_sequence_length=None,
|
||||
):
|
||||
self.operator = operator
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.max_cache_sequence_length = max_cache_sequence_length or max_sequence_length
|
||||
self.max_rotary_sequence_length = max_rotary_sequence_length or max_sequence_length
|
||||
self.past_sequence_length = past_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.head_size = head_size
|
||||
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
|
||||
self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5))
|
||||
|
||||
# Derived values
|
||||
self.total_sequence_length = sequence_length + past_sequence_length
|
||||
self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length
|
||||
self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length)
|
||||
self.past_buffer_length = self.max_cache_sequence_length if share_buffer else past_sequence_length
|
||||
self.present_buffer_length = (
|
||||
self.max_cache_sequence_length if share_buffer else (past_sequence_length + sequence_length)
|
||||
)
|
||||
|
||||
self.do_rotary = do_rotary
|
||||
self.rotary_interleaved = rotary_interleaved
|
||||
|
|
@ -75,8 +81,8 @@ class AttentionConfig:
|
|||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
"cos_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
"sin_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
}
|
||||
|
||||
if not self.is_packed_qkv:
|
||||
|
|
@ -92,7 +98,7 @@ class AttentionConfig:
|
|||
def get_cos_sin_cache(self, dtype):
|
||||
rotary_fraction = 1.0
|
||||
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
|
||||
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
|
||||
angle = torch.rand(self.max_rotary_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
|
||||
cos = torch.cos(angle).to(dtype=dtype)
|
||||
sin = torch.sin(angle).to(dtype=dtype)
|
||||
return cos.to(device=self.device), sin.to(device=self.device)
|
||||
|
|
@ -151,6 +157,8 @@ class GroupQueryAttentionConfig(AttentionConfig):
|
|||
local_window_size: int = -1,
|
||||
attention_mask=None,
|
||||
is_packed_qkv=False,
|
||||
max_cache_sequence_length=None,
|
||||
max_rotary_sequence_length=None,
|
||||
):
|
||||
super().__init__(
|
||||
"GroupQueryAttention",
|
||||
|
|
@ -166,6 +174,8 @@ class GroupQueryAttentionConfig(AttentionConfig):
|
|||
rotary_interleaved,
|
||||
device,
|
||||
is_packed_qkv=is_packed_qkv,
|
||||
max_cache_sequence_length=max_cache_sequence_length,
|
||||
max_rotary_sequence_length=max_rotary_sequence_length,
|
||||
)
|
||||
# local_window_size is for ORT only, not for Torch implementation.
|
||||
self.local_window_size = local_window_size
|
||||
|
|
@ -212,6 +222,8 @@ class SparseAttentionConfig(AttentionConfig):
|
|||
rotary_interleaved: bool = False,
|
||||
device="cuda",
|
||||
is_packed_qkv=False,
|
||||
max_cache_sequence_length=None,
|
||||
max_rotary_sequence_length=None,
|
||||
):
|
||||
super().__init__(
|
||||
"SparseAttention",
|
||||
|
|
@ -227,6 +239,8 @@ class SparseAttentionConfig(AttentionConfig):
|
|||
rotary_interleaved,
|
||||
device,
|
||||
is_packed_qkv=is_packed_qkv,
|
||||
max_cache_sequence_length=max_cache_sequence_length,
|
||||
max_rotary_sequence_length=max_rotary_sequence_length,
|
||||
)
|
||||
self.sparse_block_size = sparse_block_size
|
||||
self.num_layout = num_layout
|
||||
|
|
@ -237,18 +251,23 @@ class SparseAttentionConfig(AttentionConfig):
|
|||
def block_mask(self):
|
||||
return get_block_mask(self.num_layout, self.max_blocks, self.local_blocks, self.vert_stride).to(self.device)
|
||||
|
||||
def block_indices(self):
|
||||
row_indices, col_indices = dense_to_csr(self.block_mask())
|
||||
return row_indices.to(torch.int32).to(self.device), col_indices.to(torch.int32).to(self.device)
|
||||
|
||||
def dense_mask(self):
|
||||
expand_block_mask = self.block_mask()
|
||||
dense_mask = get_dense_mask(
|
||||
expand_block_mask, self.total_sequence_length, self.sequence_length, self.sparse_block_size
|
||||
self.block_mask(), self.total_sequence_length, self.sequence_length, self.sparse_block_size
|
||||
)
|
||||
return dense_mask.repeat(self.batch_size, self.num_heads // self.num_layout, 1, 1).to(self.device)
|
||||
|
||||
def shape_dict(self):
|
||||
shapes = super().shape_dict()
|
||||
block_row_indices, block_col_indices = self.block_indices()
|
||||
shapes.update(
|
||||
{
|
||||
"block_mask": (self.num_layout, self.max_blocks, self.max_blocks),
|
||||
"block_row_indices": tuple(block_row_indices.shape),
|
||||
"block_col_indices": tuple(block_col_indices.shape),
|
||||
"key_total_sequence_lengths": (self.batch_size,),
|
||||
}
|
||||
)
|
||||
|
|
@ -257,10 +276,11 @@ class SparseAttentionConfig(AttentionConfig):
|
|||
def random_inputs(self):
|
||||
feeds = super().random_inputs()
|
||||
k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length
|
||||
block_row_indices, block_col_indices = self.block_indices()
|
||||
feeds.update(
|
||||
{
|
||||
"block_mask": self.block_mask(),
|
||||
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
|
||||
"block_row_indices": block_row_indices,
|
||||
"block_col_indices": block_col_indices,
|
||||
"key_total_sequence_lengths": k_seqlens,
|
||||
}
|
||||
)
|
||||
|
|
@ -281,6 +301,8 @@ class SparseAttentionConfig(AttentionConfig):
|
|||
self.device,
|
||||
local_window_size=self.local_blocks * self.sparse_block_size if use_local else -1,
|
||||
is_packed_qkv=self.is_packed_qkv,
|
||||
max_cache_sequence_length=self.max_cache_sequence_length,
|
||||
max_rotary_sequence_length=self.max_rotary_sequence_length,
|
||||
)
|
||||
|
||||
def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttentionConfig:
|
||||
|
|
@ -305,6 +327,8 @@ class SparseAttentionConfig(AttentionConfig):
|
|||
self.device,
|
||||
attention_mask=attention_mask,
|
||||
is_packed_qkv=False, # torch reference implementation does not support packed qkv.
|
||||
max_cache_sequence_length=self.max_cache_sequence_length,
|
||||
max_rotary_sequence_length=self.max_rotary_sequence_length,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -327,6 +351,19 @@ def get_block_mask(num_layout, max_blocks, local_blocks, vert_stride):
|
|||
return block_mask
|
||||
|
||||
|
||||
def dense_to_csr(x):
|
||||
"""Turning a 3D torch tensor (x) to CSR rows/cols indexing."""
|
||||
assert x.dim() == 3
|
||||
pad = -1
|
||||
x = [xi.to_sparse_csr() for xi in x]
|
||||
row_indices = torch.vstack([xi.crow_indices() for xi in x])
|
||||
cols = [xi.col_indices() for xi in x]
|
||||
max_cols = max(len(xi) for xi in cols)
|
||||
cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
|
||||
col_indices = torch.vstack(cols)
|
||||
return row_indices, col_indices
|
||||
|
||||
|
||||
def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size):
|
||||
dense_mask = torch.kron(block_mask, block_mask.new_ones((block_size, block_size)))[
|
||||
:, :total_seq_len, :total_seq_len
|
||||
|
|
@ -350,7 +387,8 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig):
|
|||
"value" + suffix if not config.is_packed_qkv else "",
|
||||
"past_key" + suffix,
|
||||
"past_value" + suffix,
|
||||
"block_mask",
|
||||
"block_row_indices", # no suffix since int32 need not cast for bfloat graph.
|
||||
"block_col_indices",
|
||||
"total_sequence_length" if config.share_buffer else "",
|
||||
"key_total_sequence_lengths",
|
||||
"cos_cache" + suffix if config.do_rotary else "",
|
||||
|
|
@ -410,7 +448,12 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig):
|
|||
[
|
||||
helper.make_tensor_value_info("past_key", io_float_type, list(shape_dict["past_key"])),
|
||||
helper.make_tensor_value_info("past_value", io_float_type, list(shape_dict["past_value"])),
|
||||
helper.make_tensor_value_info("block_mask", TensorProto.INT32, list(shape_dict["block_mask"])),
|
||||
helper.make_tensor_value_info(
|
||||
"block_row_indices", TensorProto.INT32, list(shape_dict["block_row_indices"])
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"block_col_indices", TensorProto.INT32, list(shape_dict["block_col_indices"])
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
|
||||
),
|
||||
|
|
@ -704,7 +747,8 @@ class OrtSparseAttention:
|
|||
print("query(BSNH, SA)", query)
|
||||
print("key(BSNH, SA)", key)
|
||||
print("value(BSNH, SA)", value)
|
||||
print("block_mask (SA)", self.feed_dict["block_mask"])
|
||||
print("block_row_indices", self.feed_dict["block_row_indices"])
|
||||
print("block_col_indices", self.feed_dict["block_col_indices"])
|
||||
print("total_sequence_length", self.feed_dict["total_sequence_length"])
|
||||
print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"])
|
||||
|
||||
|
|
@ -778,6 +822,7 @@ class TestSparseAttention(unittest.TestCase):
|
|||
softmax_scale=1.8 / (128**0.5),
|
||||
device=device,
|
||||
is_packed_qkv=packed_qkv,
|
||||
max_cache_sequence_length=None if seq_len >= 128 else 128, # test smaller kv cache buffer.
|
||||
)
|
||||
self.run_one_relevance_test(config)
|
||||
|
||||
|
|
@ -806,6 +851,7 @@ class TestSparseAttention(unittest.TestCase):
|
|||
rotary_interleaved=(past_seq_len % 2 == 1),
|
||||
device=device,
|
||||
is_packed_qkv=packed_qkv,
|
||||
max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer.
|
||||
)
|
||||
|
||||
if do_rotary:
|
||||
|
|
|
|||
Loading…
Reference in a new issue