GQA Rotary and Packed QKV with Flash (#18906)

### Description
These changes add rotary embedding and packed QKV input to GroupQueryAttention
(GQA). For now, both features are supported only with Flash-Attention
(SM >= 80); support with Memory Efficient Attention should follow soon.



### Motivation and Context
Fusing rotary embedding into this Attention op should yield some perf
gain. Packed QKV should also provide some perf gain for certain models,
such as Llama2, that benefit from running ops on the fused QKV matrix
rather than on separate Q, K, and V.

---------

Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
aciddelgado 2024-01-23 16:34:26 -08:00 committed by GitHub
parent 532f8c642c
commit cbb29d80ff
15 changed files with 1517 additions and 272 deletions

@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
</dl>
#### Inputs
#### Inputs (7 - 9)
<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
<dt><tt>value</tt> : T</dt>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of total sequence length (past + new).</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>
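
The `rotary_interleaved` attribute and the `cos_cache` / `sin_cache` inputs above can be illustrated with a small host-side sketch. It is not the CUDA kernel from this commit: `apply_rotary` is a hypothetical helper, it works on `float` rather than fp16/bf16, and it assumes the rotary dimension equals `head_size`. It also assumes the usual convention that the non-interleaved mode pairs element `i` with element `i + head_size / 2`, while the interleaved mode pairs adjacent elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of ONNX Runtime): rotate one head vector `x`
// of length head_size for the token at position `pos`.
// cos_cache / sin_cache are row-major with shape (max_sequence_length, head_size / 2).
std::vector<float> apply_rotary(const std::vector<float>& x,
                                const std::vector<float>& cos_cache,
                                const std::vector<float>& sin_cache,
                                std::size_t pos,
                                bool interleaved) {
  const std::size_t head_size = x.size();
  assert(head_size % 2 == 0);
  const std::size_t half = head_size / 2;
  assert(cos_cache.size() >= (pos + 1) * half);
  assert(sin_cache.size() >= (pos + 1) * half);

  std::vector<float> out(head_size);
  for (std::size_t i = 0; i < half; ++i) {
    const float c = cos_cache[pos * half + i];
    const float s = sin_cache[pos * half + i];
    // Interleaved: rotate the adjacent pair (2i, 2i + 1).
    // Otherwise: rotate element i together with element i + half.
    const std::size_t a = interleaved ? 2 * i : i;
    const std::size_t b = interleaved ? 2 * i + 1 : i + half;
    out[a] = x[a] * c - x[b] * s;
    out[b] = x[a] * s + x[b] * c;
  }
  return out;
}
```

Calling `apply_rotary(x, cos_cache, sin_cache, pos, /*interleaved=*/false)` once per head and per token reproduces the unfused rotation under those assumptions; the fused op instead applies it inside the attention kernel.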
#### Outputs

@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|

@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;
};
namespace attention {

@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size) {
// if (seqlen_q == 1) {
// is_causal = false;
// } // causal=true is the same as causal=false in this case
int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
// In kv-cache case, seqlen_k_max as kv sequence length
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_causal ? 0 : -1);
params.dprops = &dprops;
if (k != nullptr && v != nullptr) {
if (k_new != nullptr && v_new != nullptr) {
params.seqlen_knew = seqlen_k_new;
params.knew_ptr = k;
params.vnew_ptr = v;
params.knew_ptr = k_new;
params.vnew_ptr = v_new;
// All strides are in elements, not bytes.
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
if (is_packed_qkv) {
params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
} else {
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
}
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;
} else {
@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
}
if (rotary_cos != nullptr) {
params.rotary_cos_ptr = rotary_cos;
params.rotary_sin_ptr = rotary_sin;
params.is_rotary_interleaved = is_rotary_interleaved;
params.rotary_dim = (head_size / 16) * 16;
}
params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
}
// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
return Status::OK();
}
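
The packed-QKV strides set in `mha_fwd_kvcache` can be checked with a small host-side sketch. `PackedQkvLayout`, `make_packed_layout`, and `k_index` are hypothetical names, not part of the Flash-Attention API. The sketch assumes `seqlen_q == seqlen_k_new`, in which case the batch stride in the diff reduces to `sequence_length * row_stride`, and it assumes each token row is laid out as Q, then K, then V.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical description of a packed QKV buffer of shape
// (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size).
struct PackedQkvLayout {
  std::size_t row_stride;    // elements between consecutive tokens
  std::size_t batch_stride;  // elements between consecutive batches
  std::size_t k_offset;      // offset of K within one token row
  std::size_t v_offset;      // offset of V within one token row
};

PackedQkvLayout make_packed_layout(std::size_t sequence_length, std::size_t num_heads,
                                   std::size_t kv_num_heads, std::size_t head_size) {
  assert(kv_num_heads > 0 && num_heads % kv_num_heads == 0);
  PackedQkvLayout layout{};
  layout.row_stride = (num_heads + 2 * kv_num_heads) * head_size;
  layout.batch_stride = sequence_length * layout.row_stride;
  layout.k_offset = num_heads * head_size;                        // K starts after all Q heads
  layout.v_offset = layout.k_offset + kv_num_heads * head_size;   // V starts after all K heads
  return layout;
}

// Flat index of element `h` of K head `head` for token `s` in batch `b`.
std::size_t k_index(const PackedQkvLayout& layout, std::size_t head_size,
                    std::size_t b, std::size_t s, std::size_t head, std::size_t h) {
  return b * layout.batch_stride + s * layout.row_stride + layout.k_offset + head * head_size + h;
}
```

Q, K, and V share the same row and batch strides because they live in one buffer; only the starting offset differs, which is why the diff assigns identical stride values to `q_*`, `knew_*`, and `vnew_*` in the packed branch.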

@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1);
int local_window_size = -1,
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);

@ -47,6 +47,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
kv_num_heads_ = static_cast<int>(kv_num_heads);
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
#if USE_FLASH_ATTENTION
@ -62,6 +64,9 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_memory_efficient_attention_ = true;
#endif
if (!disable_flash_attention_) {
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
}
}
template <typename T>
@ -73,6 +78,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* seqlens_k = context->Input<Tensor>(5);
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);
auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@ -84,6 +91,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
value,
past_key,
past_value,
cos_cache,
sin_cache,
&parameters,
num_heads_,
kv_num_heads_,
@ -93,7 +102,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
scale_,
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;
int sequence_length = parameters.sequence_length;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
@ -139,6 +154,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
do_rotary_ == false &&
key != nullptr &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@ -182,8 +199,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* present_value = context->Output(2, present_shape);
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@ -229,6 +246,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
// Rotary
if (parameters.do_rotary) {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}
cublasHandle_t cublas = GetCublasHandle(context);

@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel {
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
int local_window_size_;
bool is_unidirectional_;
bool is_past_bsnh_;
bool do_rotary_;
bool rotary_interleaved_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr<int> zeros_;
};
} // namespace cuda
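
The role of `kZerosCount` and `zeros_` can be sketched on the host as follows. This is only an analogue of the prompt-case branch in `FlashAttention`: `select_prompt_seqlens` is a hypothetical function, `zeros` stands in for the preallocated device buffer (assumed here to hold zeros), and `scratch` stands in for `data.seqlens_k_total`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kZerosCount = 256;  // same value as the constant in the diff

// Hypothetical host-side analogue: choose the seqlens_k buffer for the prompt case,
// where every sequence starts appending at position 0.
const int* select_prompt_seqlens(int batch_size,
                                 const std::vector<int>& zeros,
                                 std::vector<int>& scratch) {
  assert(batch_size > 0);
  assert(static_cast<int>(zeros.size()) == kZerosCount);
  if (batch_size <= kZerosCount) {
    return zeros.data();  // reuse the shared zero buffer; no extra fill is needed
  }
  // Larger batches need their own buffer; this stands in for the repeat_seqlen kernel launch.
  scratch.assign(static_cast<std::size_t>(batch_size), 0);
  return scratch.data();
}
```

The design choice is to pay for the zero buffer once at kernel construction so that the common case (batch size up to 256) avoids a fill on every prompt call.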

@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query,
bool is_past_bsnh,
float scale) {
// Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length
// past_key : (B, N_k, S*, H) or (B, N_k, S-, H)
// past_value : (B, N_k, S*, H) or (B, N_k, S-, H)
// past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
// past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
// no packing for q/k/v:
// query (Q) : (B, S, D)
// key (K) : (B, S, D_kv)
// value (V) : (B, S, D_kv)
// query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
// key (K) : (B, S, D_kv) or nullptr
// value (V) : (B, S, D_kv) or nullptr
ORT_UNUSED_PARAMETER(value);
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
const bool is_packed_qkv = key == nullptr;
const auto& query_dims = query->Shape().GetDims();
const auto& key_dims = key->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query,
int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int q_hidden_size = static_cast<int>(query_dims[2]);
int head_size = static_cast<int>(q_hidden_size) / num_heads;
int head_size = 0;
int kv_hidden_size = static_cast<int>(key_dims[2]);
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
int kv_hidden_size = 0;
// Check key and value when not packed
if (!is_packed_qkv) {
head_size = static_cast<int>(q_hidden_size) / num_heads;
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
}
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
}
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
} else if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
} else if (query_dims[1] != key_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
}
kv_hidden_size = static_cast<int>(key_dims[2]);
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
} else if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch size)");
} else if (query_dims[1] != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 1 (sequence length)");
} else if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
} else {
// Check packed qkv
head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
}
q_hidden_size = head_size * num_heads;
kv_hidden_size = head_size * kv_num_heads;
}
// Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
const auto& past_key_dims = past_key->Shape().GetDims();
@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}
if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
if (static_cast<int64_t>(sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)");
}
if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 || seqlens_dim[0] != batch_size) {
@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query,
int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
if (cos_cache != nullptr && sin_cache != nullptr) {
const auto& cos_dims = cos_cache->Shape().GetDims();
const auto& sin_dims = sin_cache->Shape().GetDims();
if (head_size % 16 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size shall be a multiple of 16. Got head_size % 16 == ",
head_size % 16);
}
if (cos_dims[0] != present_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 0 must be of present_sequence_length.");
}
if (sin_dims[0] != present_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 must be of present_sequence_length.");
}
if (cos_dims[1] != (head_size / 16) * 8) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
if (sin_dims[1] != (head_size / 16) * 8) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
} else if (cos_cache != nullptr || sin_cache != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}
bool is_prompt = sequence_length != 1;
if (parameters != nullptr) {
@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query,
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
output_parameters->hidden_size = q_hidden_size;
output_parameters->num_heads = num_heads;
output_parameters->head_size = q_hidden_size / num_heads;
output_parameters->head_size = head_size;
output_parameters->kv_hidden_size = kv_hidden_size;
output_parameters->kv_num_heads = kv_num_heads;
output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
}
} // namespace group_query_attention_helper
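
The shape arithmetic in `CheckInputs` can be condensed into a short sketch. `derive_head_size` and `expected_rotary_cache_dim` are hypothetical helpers that return `-1` where the code in the diff would return `INVALID_ARGUMENT`; they cover only the head-size and cache-width checks, not the batch, sequence-length, or past-KV checks.

```cpp
#include <cassert>

// Hypothetical mirror of the head_size logic: for packed QKV, the last dimension of
// `query` holds num_heads Q heads plus kv_num_heads heads each for K and V.
int derive_head_size(int query_hidden_size, int num_heads, int kv_num_heads, bool is_packed_qkv) {
  assert(num_heads > 0 && kv_num_heads > 0);
  if (num_heads % kv_num_heads != 0) return -1;  // num_heads must be a multiple of kv_num_heads
  const int divisor = is_packed_qkv ? num_heads + 2 * kv_num_heads : num_heads;
  const int head_size = query_hidden_size / divisor;
  if (head_size % 8 != 0) return -1;  // head_size must be a multiple of 8
  return head_size;
}

// Expected second dimension of cos_cache / sin_cache, or -1 if rotary is rejected.
int expected_rotary_cache_dim(int head_size) {
  if (head_size % 16 != 0) return -1;  // rotary requires head_size to be a multiple of 16
  return (head_size / 16) * 8;         // equals head_size / 2 once the check above passes
}
```

For a Llama2-style configuration with 32 query heads, 8 KV heads, and a packed last dimension of 6144, `derive_head_size(6144, 32, 8, true)` gives 128 and `expected_rotary_cache_dim(128)` gives 64.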

@ -151,9 +151,10 @@ template <typename T>
Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data,
cudaStream_t stream,
const int max_threads_per_block) {
const int max_threads_per_block,
const bool past_only = false) {
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.sequence_length;
const int kv_sequence_length = past_only ? 0 : parameters.sequence_length;
const int past_sequence_length = parameters.seqlen_past_kv_cache;
const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int kv_num_heads = parameters.kv_num_heads;
@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
return CUDA_CALL(cudaGetLastError());
}
__global__ void PastToTotalSeqlen(int32_t* seqlens_k,
int32_t* seqlens_k_buff,
const int add_seqlen) {
@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k,
// Convert Past to Total sequence length tensor
Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream,
const int threads_per_block) {
const int threads_per_block) {
if (parameters.is_prompt) {
return Status::OK();
}
@ -482,91 +482,63 @@ Status FlashAttention(
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.sequence_length;
const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
bool is_causal = true;
bool is_bf16 = std::is_same<T, BFloat16>::value;
// Note: seqlens_k is past sequence length for flash
if (parameters.is_prompt) {
// Launch kernel to copy seqlen
constexpr int thr_per_blk = 256;
int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk;
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, parameters.sequence_length, batch_size);
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
void* key;
void* value;
if (!parameters.is_packed_qkv) {
key = reinterpret_cast<void*>(const_cast<T*>(data.key));
value = reinterpret_cast<void*>(const_cast<T*>(data.value));
} else {
const size_t key_offset = static_cast<size_t>(num_heads * head_size);
const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size);
key = reinterpret_cast<T*>(query) + key_offset;
value = reinterpret_cast<T*>(key) + value_offset;
}
void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
if (parameters.kv_share_buffer) {
// Share buffer case
if (data.past_key == nullptr || data.past_key != data.present_key) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Past and present kv shall share the same tensor when kv_share_buffer is on.");
if (parameters.is_prompt) {
// set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value
// user should use seqlens_k to index into output to get new tokens
if (batch_size <= parameters.zeros_count) {
seqlens_k = parameters.zero_ptr;
} else {
// Launch kernel to create larger seqlen tensor when batch_size > 256
constexpr int thr_per_blk = 256;
int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, 0, batch_size);
seqlens_k = data.seqlens_k_total;
}
if (parameters.is_prompt) {
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
key = nullptr;
value = nullptr;
seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
}
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
DUMP_TENSOR_INIT();
DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
seqlens_k, batch_size, num_heads, kv_num_heads,
head_size, sequence_length, present_sequence_length, kv_sequence_length,
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
} else {
// Not share buffer case
// Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inefficient
if (data.past_key != nullptr && data.past_key == data.present_key) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Past and present kv share the same tensor but kv_share_buffer is not on.");
}
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
if (!parameters.is_prompt) {
ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
}
seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
DUMP_TENSOR_INIT();
DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size);
DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size);
DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size);
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast<void*>(data.softmax_lse),
seqlens_k, batch_size, num_heads, kv_num_heads,
head_size, sequence_length, present_sequence_length, 0,
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
} else if (!parameters.kv_share_buffer) { // copy past kv to present kv
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true));
}
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length,
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
parameters.is_packed_qkv));
// if (parameters.left_padding && parameters.is_prompt) {
// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
// }
DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
@ -672,7 +644,6 @@ Status EfficientAttention(
p.has_custom_right_padding = true;
run_memory_efficient_attention(p);
DUMP_TENSOR_INIT();
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
return Status::OK();

@ -21,6 +21,8 @@ struct GroupQueryAttentionData {
const T* past_key = nullptr;
const T* past_value = nullptr;
int* seqlens_k = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;
// Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;


@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = query_dims[2];
updateOutputShape(ctx, 0, output_shape);
} else {
fail_shape_inference("Missing input 2 (value)");
}
}
if (ctx.getNumOutputs() > 1) { // has present output
if (hasInputShape(ctx, past_key_index)) {
// auto& query_shape = getInputShape(ctx, 0);
// auto& query_dims = query_shape.dim();
auto& past_shape = getInputShape(ctx, past_key_index);
auto& past_dims = past_shape.dim();
if (past_dims.size() != 4) {
@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
}
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1);
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
// TODO(aciddelgado): propagate output shapes depending on whether kv-share buffer is on or not
}
}
}
@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"left_window_size for local attention (like Mistral). Default value is -1 meaning unused.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Attr("do_rotary",
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_interleaved",
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size)",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape"
"(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).",
"T")
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ",
"T")
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, kv_hidden_size)",
"T")
"T",
OpSchema::Optional)
.Input(3,
"past_key",
"past state key with support for format BNSH. When past_key uses same tensor as present_key"
@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"total_sequence_length",
"Scalar tensor of total sequence length (past + new).",
"M")
.Input(7,
"cos_cache",
"2D tensor with shape (max_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
.Input(8,
"sin_cache",
"2D tensor with shape (max_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",


@ -0,0 +1,693 @@
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Tuple, Union
import torch
import triton
import triton.language as tl
from einops import rearrange, repeat
##### TRITON KERNEL FOR ROTARY #####
# @triton.autotune(
# configs=[
# triton.Config({"block_m": 2}),
# triton.Config({"block_m": 4}),
# triton.Config({"block_m": 8}),
# triton.Config({"block_m": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
# )
@triton.jit
def rotary_kernel(
out_, # Pointers to matrices
x_,
cos_,
sin_,
CU_SEQLENS,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
nheads,
rotary_dim,
seqlen_ro,
CACHE_KEY_SEQLEN,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
block_k: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
block_m: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
if not IS_VARLEN:
x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads
out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads
else:
start_idx = tl.load(CU_SEQLENS + pid_batch)
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
if pid_m * block_m >= seqlen:
return
rm = pid_m * block_m + tl.arange(0, block_m)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
rk = tl.arange(0, block_k)
rk_half = tl.arange(0, block_k // 2)
if not INTERLEAVED:
# Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_
x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(
tl.float32
)
sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)
x1 = tl.load(
x_ + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if CONJUGATE:
sin = -sin
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
# write back result
out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
tl.store(
out_ + rotary_dim_half * stride_out_headdim,
o1,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
)
else:
# We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick out the right outputs for the even
# and for the odd indices.
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, block_k) // 2
x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
cos = tl.load(
cos_,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
sin_,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)
x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)
if CONJUGATE:
sin = -sin
x0_cos = x0 * cos
x1_sin = x1 * sin
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
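
The index arithmetic in the interleaved branch of `rotary_kernel` can be checked outside Triton. The sketch below reproduces the `rk_swap` and `rk_repeat` expressions with plain Python lists; the small `block_k` value is chosen only for illustration.

```python
block_k = 8  # illustrative size; the kernel uses 32, 64, 128 or 256

rk = list(range(block_k))
# Partner index of each element inside its (even, odd) pair.
rk_swap = [k + ((k + 1) % 2) * 2 - 1 for k in rk]
# Column of the cos/sin table shared by both elements of a pair.
rk_repeat = [k // 2 for k in rk]
```

Each element at position `k` is therefore combined with its partner at `rk_swap[k]` using the angle stored in column `rk_repeat[k]`.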
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
"""
is_varlen = cu_seqlens is not None
if not is_varlen:
batch, seqlen, nheads, headdim = x.shape
else:
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
total_seqlen, nheads, headdim = x.shape
batch_p_1 = cu_seqlens.shape[0]
batch = batch_p_1 - 1
seqlen = max_seqlen
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa
block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
block_k,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate,
block_m,
)
return output
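
The launch parameters chosen in `apply_rotary` follow directly from `rotary_dim`, `interleaved` and the tensor shape. The sketch below repeats that selection logic without Triton so the resulting grid can be inspected. The helper name `launch_config` is hypothetical, and the ceiling division stands in for `triton.cdiv`.

```python
def launch_config(seqlen, batch, nheads, rotary_dim, interleaved):
    """Mirror the block_k / block_m / grid selection used by apply_rotary."""
    block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
    # Ceiling division: number of row blocks needed to cover seqlen.
    grid = (-(-seqlen // block_m), batch, nheads)
    return block_k, block_m, grid


wide = launch_config(seqlen=100, batch=2, nheads=32, rotary_dim=128, interleaved=False)
narrow = launch_config(seqlen=100, batch=2, nheads=32, rotary_dim=64, interleaved=False)
```

One program instance handles `block_m` sequence positions of a single head in a single batch entry, which is why the grid has three axes.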
##### ROTARY API #####
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
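
The difference between the two layouts handled by `rotate_half` and `apply_rotary_emb_torch` is only which elements form a pair. The sketch below applies the same rotation to one head vector with plain Python lists. The function name `rotate_row` is hypothetical, and the input is assumed to cover exactly the rotary dimensions.

```python
def rotate_row(x, cos, sin, interleaved=False):
    """Rotate one head vector x (length rotary_dim) pair by pair.

    cos and sin hold one value per pair (length rotary_dim / 2).
    """
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        # GPT-J style pairs neighbours, GPT-NeoX style pairs the two halves.
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[a] * sin[i] + x[b] * cos[i]
    return out


x = [1.0, 2.0, 3.0, 4.0]
quarter_turn_cos, quarter_turn_sin = [0.0, 0.0], [1.0, 1.0]  # 90 degrees for every pair
neox = rotate_row(x, quarter_turn_cos, quarter_turn_sin, interleaved=False)
gptj = rotate_row(x, quarter_turn_cos, quarter_turn_sin, interleaved=True)
```

In the non-interleaved case elements 0 and 2 rotate together, while in the interleaved case elements 0 and 1 do.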
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
out = apply_rotary(
x,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
interleaved=interleaved,
inplace=inplace,
)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.inplace = inplace
ctx.max_seqlen = max_seqlen
return out if not inplace else x
@staticmethod
def backward(ctx, do):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cu_seqlens = ctx.saved_tensors
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
if not ctx.interleaved and not ctx.inplace:
do = do.clone()
dx = apply_rotary(
do,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=ctx.max_seqlen,
interleaved=ctx.interleaved,
inplace=ctx.inplace,
conjugate=True,
)
return dx, None, None, None, None, None, None, None
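
The backward pass reuses `apply_rotary` with `conjugate=True`, which negates `sin`. Since each pair is transformed by a plain 2D rotation, negating `sin` gives the inverse (transposed) rotation, which is what the gradient with respect to the input requires. The sketch below checks this on a single pair; `rotate_pair` is a hypothetical helper, not part of the file.

```python
import math


def rotate_pair(x0, x1, theta, conjugate=False):
    """Rotate the pair (x0, x1) by theta, or by -theta when conjugate is set."""
    cos, sin = math.cos(theta), math.sin(theta)
    if conjugate:
        sin = -sin
    return x0 * cos - x1 * sin, x0 * sin + x1 * cos


y0, y1 = rotate_pair(0.5, -2.0, theta=0.7)
r0, r1 = rotate_pair(y0, y1, theta=0.7, conjugate=True)  # undoes the rotation
```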
def apply_rotary_emb(
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
class ApplyRotaryEmbQKV(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
cos,
sin,
cos_k=None,
sin_k=None,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
if cos_k is None and sin_k is None and qkv.is_contiguous():
# Call 1 kernel instead of 2 kernels
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
q, k = qkv[:, :, 0], qkv[:, :, 1]
apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
return qkv
@staticmethod
def backward(ctx, dqkv):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cos_k, sin_k = ctx.saved_tensors
if cos_k is None and sin_k is None and dqkv.is_contiguous():
# Call 1 kernel instead of 2 kernels
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
apply_rotary(
dqk,
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True)
apply_rotary(
dk,
cos_k,
sin_k,
seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
return dqkv, None, None, None, None, None, None
def apply_rotary_emb_qkv_(
qkv,
cos,
sin,
cos_k=None,
sin_k=None,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
qkv: (batch_size, seqlen, 3, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
"""
return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
class ApplyRotaryEmbKV(torch.autograd.Function):
@staticmethod
def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
batch, seqlen, two, nheads, headdim = kv.shape
assert two == 2
k = kv[:, :, 0]
apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
return kv
@staticmethod
def backward(ctx, dkv):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, seqlen_offsets = ctx.saved_tensors
else:
cos, sin = ctx.saved_tensors
apply_rotary(
dkv[:, :, 0],
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
return dkv, None, None, None, None
def apply_rotary_emb_kv_(
kv,
cos,
sin,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
"""
Arguments:
kv: (batch_size, seqlen, 2, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
kv: (batch_size, seqlen, 2, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of K.
"""
return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets)
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et al.).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_; GPT-NeoX was an inspiration.
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 1992.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
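
The `pos_idx_in_fp32` docstring notes that position indices stored in bf16 can coincide. The sketch below emulates bf16 rounding with the standard library to show several neighbouring positions collapsing to one value. It assumes round-to-nearest-even and ignores NaN and infinity; `round_to_bf16` is a hypothetical helper.

```python
import struct


def round_to_bf16(x):
    """Round a float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half of the dropped range, plus the tie-breaking bit, then truncate.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


positions = [1992.0, 1993.0, 1994.0, 1995.0]
rounded = [round_to_bf16(p) for p in positions]
```

Between 1024 and 2048 the representable bf16 values are 8 apart, so these four positions would all receive the same rotary embedding.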
def _compute_inv_freq(self, device=None):
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
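
The unscaled branch of `_update_cos_sin_cache` amounts to an outer product of positions and inverse frequencies followed by `cos` and `sin`. The sketch below builds the same tables with plain Python lists, giving the `(seqlen, dim / 2)` shape that the `cos_cache` and `sin_cache` operator inputs describe. The function name `build_cos_sin_cache` is hypothetical, and the XPos scaling is left out.

```python
import math


def build_cos_sin_cache(seqlen, dim, base=10000.0):
    """Return cos and sin tables of shape (seqlen, dim // 2) as nested lists."""
    # One inverse frequency per rotated pair, as in _compute_inv_freq.
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    cos = [[math.cos(t * f) for f in inv_freq] for t in range(seqlen)]
    sin = [[math.sin(t * f) for f in inv_freq] for t in range(seqlen)]
    return cos, sin


cos, sin = build_cos_sin_cache(seqlen=16, dim=8)
```

Row 0 is the identity rotation (all cosines 1, all sines 0), and the first column advances by one radian per position because its inverse frequency is 1.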
def forward(
self,
qkv: torch.Tensor,
kv: Optional[torch.Tensor] = None,
seqlen_offset: Union[int, torch.Tensor] = 0,
max_seqlen: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
else it's just q of shape (batch, seqlen, nheads, headdim)
kv: (batch, seqlen, 2, nheads, headdim)
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
should pass in max_seqlen, which will update the cos / sin cache up to that length.
Apply rotary embedding *inplace* to qkv and / or kv.
"""
seqlen = qkv.shape[1]
if max_seqlen is not None:
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
elif isinstance(seqlen_offset, int):
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
if kv is None:
if self.scale is None:
return apply_rotary_emb_qkv_(
qkv,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
return apply_rotary_emb_qkv_(
qkv,
self._cos_cached,
self._sin_cached,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
q = qkv
q = apply_rotary_emb_func(
q,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
inplace=True,
seqlen_offsets=seqlen_offset,
)
if self.scale is None:
kv = apply_rotary_emb_kv_(
kv,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
kv = apply_rotary_emb_kv_(
kv,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
return q, kv


@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
numpy_init_version = numpy.__version__
pb_init_version = google.protobuf.__version__
run_subprocess(
[sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR
[sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"],
cwd=SCRIPT_DIR,
)
run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd)
# Restore initial numpy/protobuf version in case other tests use it


@ -3,7 +3,8 @@ packaging
protobuf==3.20.2
numpy==1.24.0 ; python_version < '3.12'
numpy==1.26.0 ; python_version >= '3.12'
torch
coloredlogs==15.0
transformers==4.36.0
psutil
einops