mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
GQA Rotary and Packed QKV with Flash (#18906)
### Description These changes add rotary embedding and packed qkv input to gqa. As of now, the changes are only supported with Flash-Attention (SM >= 80) but should soon be supported with Memory Efficient Attention as well. ### Motivation and Context With the fusion of rotary embedding into this Attention op, we hope to observe some perf gain. The packed QKV should also provide some perf gain in the context of certain models, like Llama2, that would benefit from running ops on the fused QKV matrix, rather than the separate Q, K, and V. --------- Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
This commit is contained in:
parent
532f8c642c
commit
cbb29d80ff
15 changed files with 1517 additions and 272 deletions
|
|
@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>do_rotary</tt> : int</dt>
|
||||
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
|
||||
<dt><tt>kv_num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads for k and v</dd>
|
||||
<dt><tt>local_window_size</tt> : int</dt>
|
||||
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
|
||||
<dt><tt>num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads for q</dd>
|
||||
<dt><tt>rotary_interleaved</tt> : int</dt>
|
||||
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
|
||||
<dt><tt>scale</tt> : float</dt>
|
||||
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
#### Inputs (7 - 9)
|
||||
|
||||
<dl>
|
||||
<dt><tt>query</tt> : T</dt>
|
||||
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
|
||||
<dt><tt>key</tt> : T</dt>
|
||||
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
|
||||
<dt><tt>key</tt> (optional) : T</dt>
|
||||
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
|
||||
<dt><tt>value</tt> : T</dt>
|
||||
<dt><tt>value</tt> (optional) : T</dt>
|
||||
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
|
||||
<dt><tt>past_key</tt> (optional) : T</dt>
|
||||
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
|
||||
|
|
@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
|
||||
<dt><tt>total_sequence_length</tt> : M</dt>
|
||||
<dd>Scalar tensor of total sequence length (past + new).</dd>
|
||||
<dt><tt>cos_cache</tt> (optional) : T</dt>
|
||||
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
|
||||
<dt><tt>sin_cache</tt> (optional) : T</dt>
|
||||
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
|
|
|||
|
|
@ -843,7 +843,7 @@ Do not modify directly.*
|
|||
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|
||||
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|
||||
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|
||||
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
|
|||
bool is_unidirectional; // causal
|
||||
int local_window_size;
|
||||
bool kv_share_buffer;
|
||||
bool is_packed_qkv;
|
||||
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
|
||||
bool do_rotary;
|
||||
bool rotary_interleaved;
|
||||
float scale;
|
||||
AttentionQkvFormat qkv_format;
|
||||
AttentionQkvFormat past_kv_format;
|
||||
int zeros_count;
|
||||
int* zero_ptr;
|
||||
};
|
||||
|
||||
namespace attention {
|
||||
|
|
|
|||
|
|
@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
|
|||
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
||||
cudaStream_t stream,
|
||||
void* q, // batch_size x seqlen_q x num_heads x head_size
|
||||
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
|
||||
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
|
||||
void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
|
||||
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
|
||||
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
|
||||
void* out, // batch_size x seqlen_q x num_heads x head_size
|
||||
void* softmax_lse, // batch_size x num_heads x seqlen_q
|
||||
void* seqlens_k_, // batch_size
|
||||
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
|
||||
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
|
||||
int batch_size,
|
||||
int num_heads,
|
||||
int num_heads_k,
|
||||
|
|
@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
int num_splits,
|
||||
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
|
||||
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
|
||||
int local_window_size) {
|
||||
// if (seqlen_q == 1) {
|
||||
// is_causal = false;
|
||||
// } // causal=true is the same as causal=false in this case
|
||||
|
||||
int local_window_size,
|
||||
bool is_rotary_interleaved,
|
||||
bool is_packed_qkv) {
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
// In kv-cache case, seqlen_k_max as kv sequence length
|
||||
Flash_fwd_params params;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
|
|
@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
is_causal ? 0 : -1);
|
||||
params.dprops = &dprops;
|
||||
|
||||
if (k != nullptr && v != nullptr) {
|
||||
if (k_new != nullptr && v_new != nullptr) {
|
||||
params.seqlen_knew = seqlen_k_new;
|
||||
params.knew_ptr = k;
|
||||
params.vnew_ptr = v;
|
||||
params.knew_ptr = k_new;
|
||||
params.vnew_ptr = v_new;
|
||||
// All stride are in elements, not bytes.
|
||||
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
|
||||
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
|
||||
params.knew_row_stride = num_heads_k * head_size;
|
||||
params.vnew_row_stride = num_heads_k * head_size;
|
||||
if (is_packed_qkv) {
|
||||
params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
|
||||
params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
|
||||
params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
|
||||
params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
|
||||
params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
|
||||
params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
|
||||
} else {
|
||||
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
|
||||
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
|
||||
params.knew_row_stride = num_heads_k * head_size;
|
||||
params.vnew_row_stride = num_heads_k * head_size;
|
||||
}
|
||||
params.knew_head_stride = head_size;
|
||||
params.vnew_head_stride = head_size;
|
||||
} else {
|
||||
|
|
@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
|
||||
}
|
||||
|
||||
if (rotary_cos != nullptr) {
|
||||
params.rotary_cos_ptr = rotary_cos;
|
||||
params.rotary_sin_ptr = rotary_sin;
|
||||
params.is_rotary_interleaved = is_rotary_interleaved;
|
||||
params.rotary_dim = (head_size / 16) * 16;
|
||||
}
|
||||
|
||||
params.num_splits = num_splits;
|
||||
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum;
|
||||
|
|
@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
}
|
||||
|
||||
// Only split kernel supports appending to KV cache
|
||||
run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
|
||||
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
void* out, // batch_size x seqlen_q x num_heads x head_size
|
||||
void* softmax_lse, // batch_size x num_heads x seqlen_q
|
||||
void* seqlens_k_, // batch_size
|
||||
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
|
||||
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
|
||||
int batch_size,
|
||||
int num_heads,
|
||||
int num_heads_k,
|
||||
|
|
@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
int num_splits = 0,
|
||||
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
|
||||
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
|
||||
int local_window_size = -1);
|
||||
int local_window_size = -1,
|
||||
bool is_rotary_interleaved = false,
|
||||
bool is_packed_qkv = false);
|
||||
|
||||
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
|
|||
kv_num_heads_ = static_cast<int>(kv_num_heads);
|
||||
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
|
||||
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
|
||||
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
|
||||
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
|
@ -62,6 +64,9 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
|
|||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
if (!disable_flash_attention_) {
|
||||
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -73,6 +78,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* past_value = context->Input<Tensor>(4);
|
||||
const Tensor* seqlens_k = context->Input<Tensor>(5);
|
||||
const Tensor* total_seqlen = context->Input<Tensor>(6);
|
||||
const Tensor* cos_cache = context->Input<Tensor>(7);
|
||||
const Tensor* sin_cache = context->Input<Tensor>(8);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
GroupQueryAttentionParameters parameters;
|
||||
|
|
@ -84,6 +91,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
value,
|
||||
past_key,
|
||||
past_value,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
¶meters,
|
||||
num_heads_,
|
||||
kv_num_heads_,
|
||||
|
|
@ -93,7 +102,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
scale_,
|
||||
device_prop.maxThreadsPerBlock));
|
||||
parameters.local_window_size = local_window_size_;
|
||||
parameters.is_unidirectional = is_unidirectional_;
|
||||
parameters.zeros_count = kZerosCount;
|
||||
parameters.zero_ptr = zeros_.get();
|
||||
// parameters.left_padding = left_padding_;
|
||||
int sequence_length = parameters.sequence_length;
|
||||
parameters.do_rotary = do_rotary_;
|
||||
parameters.rotary_interleaved = rotary_interleaved_;
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
|
||||
|
|
@ -139,6 +154,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
!use_flash_attention &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
local_window_size_ == -1 &&
|
||||
do_rotary_ == false &&
|
||||
key != nullptr &&
|
||||
(parameters.head_size & 7) == 0 &&
|
||||
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
|
||||
|
|
@ -182,8 +199,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
Tensor* present_value = context->Output(2, present_shape);
|
||||
|
||||
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
|
||||
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
|
||||
data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
|
||||
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
|
||||
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
|
||||
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
|
||||
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
|
||||
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
|
||||
|
|
@ -229,6 +246,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
if (fmha_buffer != nullptr) {
|
||||
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
|
||||
}
|
||||
// Rotary
|
||||
if (parameters.do_rotary) {
|
||||
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
|
||||
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
|
||||
}
|
||||
|
||||
cublasHandle_t cublas = GetCublasHandle(context);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel {
|
|||
int num_heads_; // number of attention heads
|
||||
int kv_num_heads_; // different for k and v for group query attention
|
||||
int local_window_size_;
|
||||
bool is_unidirectional_;
|
||||
bool is_past_bsnh_;
|
||||
bool do_rotary_;
|
||||
bool rotary_interleaved_;
|
||||
float scale_;
|
||||
bool disable_flash_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
|
||||
IAllocatorUniquePtr<int> zeros_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query,
|
|||
const Tensor* value,
|
||||
const Tensor* past_key,
|
||||
const Tensor* past_value,
|
||||
const Tensor* cos_cache,
|
||||
const Tensor* sin_cache,
|
||||
void* parameters,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
|
|
@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query,
|
|||
bool is_past_bsnh,
|
||||
float scale) {
|
||||
// Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length
|
||||
// past_key : (B, N_k, S*, H) or (B, N_k, S-, H)
|
||||
// past_value : (B, N_k, S*, H) or (B, N_k, S-, H)
|
||||
// past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
|
||||
// past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
|
||||
// no packing for q/k/v:
|
||||
// query (Q) : (B, S, D)
|
||||
// key (K) : (B, S, D_kv)
|
||||
// value (V) : (B, S, D_kv)
|
||||
// query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
|
||||
// key (K) : (B, S, D_kv) or nullptr
|
||||
// value (V) : (B, S, D_kv) or nullptr
|
||||
ORT_UNUSED_PARAMETER(value);
|
||||
|
||||
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
|
||||
AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
|
||||
|
||||
const bool is_packed_qkv = key == nullptr;
|
||||
const auto& query_dims = query->Shape().GetDims();
|
||||
const auto& key_dims = key->Shape().GetDims();
|
||||
|
||||
if (query_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
|
||||
|
|
@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query,
|
|||
int batch_size = static_cast<int>(query_dims[0]);
|
||||
int sequence_length = static_cast<int>(query_dims[1]);
|
||||
int q_hidden_size = static_cast<int>(query_dims[2]);
|
||||
int head_size = static_cast<int>(q_hidden_size) / num_heads;
|
||||
int head_size = 0;
|
||||
|
||||
int kv_hidden_size = static_cast<int>(key_dims[2]);
|
||||
if (num_heads % kv_num_heads != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
|
||||
num_heads % kv_num_heads);
|
||||
}
|
||||
|
||||
int kv_hidden_size = 0;
|
||||
// Check key and value when not packed
|
||||
if (!is_packed_qkv) {
|
||||
head_size = static_cast<int>(q_hidden_size) / num_heads;
|
||||
if (head_size % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"head_size must be a multiple of 8. Got head_size % 8 == ",
|
||||
head_size % 8);
|
||||
}
|
||||
if (value == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
|
||||
}
|
||||
const auto& key_dims = key->Shape().GetDims();
|
||||
if (key_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
|
||||
key_dims.size());
|
||||
} else if (query_dims[0] != key_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'key' shall have same dim 0 (batch size)");
|
||||
} else if (query_dims[1] != key_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
|
||||
}
|
||||
kv_hidden_size = static_cast<int>(key_dims[2]);
|
||||
const auto& value_dims = value->Shape().GetDims();
|
||||
if (value_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
|
||||
value_dims.size());
|
||||
} else if (query_dims[0] != value_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'value' shall have same dim 0 (batch size)");
|
||||
} else if (query_dims[1] != value_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'value' shall have same dim 1 (sequence length)");
|
||||
} else if (value_dims[2] != kv_hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
|
||||
}
|
||||
} else {
|
||||
// Check packed qkv
|
||||
head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
|
||||
if (head_size % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"head_size must be a multiple of 8. Got head_size % 8 == ",
|
||||
head_size % 8);
|
||||
}
|
||||
if (value != nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
|
||||
}
|
||||
q_hidden_size = head_size * num_heads;
|
||||
kv_hidden_size = head_size * kv_num_heads;
|
||||
}
|
||||
|
||||
// Check past-present KV
|
||||
int32_t past_sequence_length = 0;
|
||||
if (past_key != nullptr && past_value != nullptr) {
|
||||
const auto& past_key_dims = past_key->Shape().GetDims();
|
||||
|
|
@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query,
|
|||
"Input 'past_key' and 'past_value' shall be both present or both absent.");
|
||||
}
|
||||
|
||||
if (key_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
|
||||
key_dims.size());
|
||||
}
|
||||
if (query_dims[0] != key_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'key' shall have same dim 0 (batch size)");
|
||||
}
|
||||
|
||||
if (num_heads % kv_num_heads != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
|
||||
num_heads % kv_num_heads);
|
||||
}
|
||||
|
||||
const auto& value_dims = value->Shape().GetDims();
|
||||
if (value_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
|
||||
value_dims.size());
|
||||
}
|
||||
|
||||
if (query_dims[0] != value_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(sequence_length) != value_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)");
|
||||
}
|
||||
|
||||
if (value_dims[2] != kv_hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
|
||||
}
|
||||
|
||||
// Check seqlens_k tensor (holding past seqlen for token gen)
|
||||
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
|
||||
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
|
||||
|
|
@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query,
|
|||
int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
|
||||
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
|
||||
|
||||
if (cos_cache != nullptr && sin_cache != nullptr) {
|
||||
const auto& cos_dims = cos_cache->Shape().GetDims();
|
||||
const auto& sin_dims = sin_cache->Shape().GetDims();
|
||||
|
||||
if (head_size % 16 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"head_size shall be a multiple of 16. Got head_size % 16 == ",
|
||||
head_size % 16);
|
||||
}
|
||||
if (cos_dims[0] != present_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 0 must be of present_sequence_length.");
|
||||
}
|
||||
if (sin_dims[0] != present_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 0 must be of present_sequence_length.");
|
||||
}
|
||||
if (cos_dims[1] != (head_size / 16) * 8) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
|
||||
}
|
||||
if (sin_dims[1] != (head_size / 16) * 8) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
|
||||
}
|
||||
} else if (cos_cache != nullptr || sin_cache != nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
|
||||
}
|
||||
|
||||
bool is_prompt = sequence_length != 1;
|
||||
|
||||
if (parameters != nullptr) {
|
||||
|
|
@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query,
|
|||
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
|
||||
output_parameters->hidden_size = q_hidden_size;
|
||||
output_parameters->num_heads = num_heads;
|
||||
output_parameters->head_size = q_hidden_size / num_heads;
|
||||
output_parameters->head_size = head_size;
|
||||
output_parameters->kv_hidden_size = kv_hidden_size;
|
||||
output_parameters->kv_num_heads = kv_num_heads;
|
||||
output_parameters->is_packed_qkv = is_packed_qkv;
|
||||
output_parameters->is_unidirectional = true;
|
||||
output_parameters->is_prompt = is_prompt;
|
||||
output_parameters->scale = scale;
|
||||
|
|
@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query,
|
|||
const Tensor* value,
|
||||
const Tensor* past_key,
|
||||
const Tensor* past_value,
|
||||
const Tensor* cos_cache,
|
||||
const Tensor* sin_cache,
|
||||
void* parameters,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
|
|
@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
|
||||
}
|
||||
|
||||
return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
|
||||
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
|
||||
}
|
||||
|
||||
} // namespace group_query_attention_helper
|
||||
|
|
|
|||
|
|
@ -151,9 +151,10 @@ template <typename T>
|
|||
Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
const int max_threads_per_block,
|
||||
const bool past_only = false) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int kv_sequence_length = parameters.sequence_length;
|
||||
const int kv_sequence_length = past_only ? 0 : parameters.sequence_length;
|
||||
const int past_sequence_length = parameters.seqlen_past_kv_cache;
|
||||
const int present_sequence_length = parameters.seqlen_present_kv_cache;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
|
|
@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
|
|||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
__global__ void PastToTotalSeqlen(int32_t* seqlens_k,
|
||||
int32_t* seqlens_k_buff,
|
||||
const int add_seqlen) {
|
||||
|
|
@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k,
|
|||
// Convert Past to Total sequence length tensor
|
||||
Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
|
||||
int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream,
|
||||
const int threads_per_block) {
|
||||
const int threads_per_block) {
|
||||
if (parameters.is_prompt) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -482,91 +482,63 @@ Status FlashAttention(
|
|||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int kv_sequence_length = parameters.sequence_length;
|
||||
const int present_sequence_length = parameters.seqlen_present_kv_cache;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
|
||||
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
|
||||
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
|
||||
void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
|
||||
|
||||
bool is_causal = true;
|
||||
|
||||
bool is_bf16 = std::is_same<T, BFloat16>::value;
|
||||
|
||||
// Note: seqlens_k is past sequence length for flash
|
||||
if (parameters.is_prompt) {
|
||||
// Launch kernel to copy seqlen
|
||||
constexpr int thr_per_blk = 256;
|
||||
int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk;
|
||||
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, parameters.sequence_length, batch_size);
|
||||
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
|
||||
void* key;
|
||||
void* value;
|
||||
|
||||
if (!parameters.is_packed_qkv) {
|
||||
key = reinterpret_cast<void*>(const_cast<T*>(data.key));
|
||||
value = reinterpret_cast<void*>(const_cast<T*>(data.value));
|
||||
} else {
|
||||
const size_t key_offset = static_cast<size_t>(num_heads * head_size);
|
||||
const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size);
|
||||
key = reinterpret_cast<T*>(query) + key_offset;
|
||||
value = reinterpret_cast<T*>(key) + value_offset;
|
||||
}
|
||||
|
||||
void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
|
||||
|
||||
if (parameters.kv_share_buffer) {
|
||||
// Share buffer case
|
||||
if (data.past_key == nullptr || data.past_key != data.present_key) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Past and present kv shall share the same tensor when kv_share_buffer is on.");
|
||||
if (parameters.is_prompt) {
|
||||
// set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value
|
||||
// user should use seqlens_k to index into output to get new tokens
|
||||
if (batch_size <= parameters.zeros_count) {
|
||||
seqlens_k = parameters.zero_ptr;
|
||||
} else {
|
||||
// Launch kernel to create larger seqlen tensor when batch_size > 256
|
||||
constexpr int thr_per_blk = 256;
|
||||
int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
|
||||
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, 0, batch_size);
|
||||
seqlens_k = data.seqlens_k_total;
|
||||
}
|
||||
|
||||
if (parameters.is_prompt) {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
|
||||
key = nullptr;
|
||||
value = nullptr;
|
||||
seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
|
||||
}
|
||||
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
|
||||
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
|
||||
device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
seqlens_k, batch_size, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, present_sequence_length, kv_sequence_length,
|
||||
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
|
||||
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
|
||||
} else {
|
||||
// Not share buffer case
|
||||
// Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
|
||||
if (data.past_key != nullptr && data.past_key == data.present_key) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Past and present kv share the same tensor but kv_share_buffer is not on.");
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
|
||||
|
||||
if (!parameters.is_prompt) {
|
||||
ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
|
||||
}
|
||||
|
||||
seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
|
||||
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
|
||||
DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size);
|
||||
DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size);
|
||||
DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size);
|
||||
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
|
||||
device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast<void*>(data.softmax_lse),
|
||||
seqlens_k, batch_size, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, present_sequence_length, 0,
|
||||
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
|
||||
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
|
||||
} else if (!parameters.kv_share_buffer) { // copy past kv to present kv
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true));
|
||||
}
|
||||
|
||||
void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
|
||||
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
|
||||
void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
|
||||
void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
|
||||
|
||||
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
|
||||
device_prop, stream, query, present_key, present_value, key, value, data.output,
|
||||
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
|
||||
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
|
||||
parameters.seqlen_present_kv_cache, kv_sequence_length,
|
||||
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
|
||||
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
|
||||
parameters.is_packed_qkv));
|
||||
|
||||
// if (parameters.left_padding && parameters.is_prompt) {
|
||||
// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
|
||||
// }
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
|
||||
|
||||
|
|
@ -672,7 +644,6 @@ Status EfficientAttention(
|
|||
p.has_custom_right_padding = true;
|
||||
run_memory_efficient_attention(p);
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ struct GroupQueryAttentionData {
|
|||
const T* past_key = nullptr;
|
||||
const T* past_value = nullptr;
|
||||
int* seqlens_k = nullptr;
|
||||
const T* cos_cache = nullptr;
|
||||
const T* sin_cache = nullptr;
|
||||
// Flash buffers
|
||||
T* softmax_lse = nullptr;
|
||||
T* softmax_lse_accum = nullptr;
|
||||
|
|
|
|||
|
|
@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
|
|||
*output_shape.add_dim() = query_dims[1];
|
||||
*output_shape.add_dim() = query_dims[2];
|
||||
updateOutputShape(ctx, 0, output_shape);
|
||||
} else {
|
||||
fail_shape_inference("Missing input 2 (value)");
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx.getNumOutputs() > 1) { // has present output
|
||||
if (hasInputShape(ctx, past_key_index)) {
|
||||
// auto& query_shape = getInputShape(ctx, 0);
|
||||
// auto& query_dims = query_shape.dim();
|
||||
auto& past_shape = getInputShape(ctx, past_key_index);
|
||||
auto& past_dims = past_shape.dim();
|
||||
if (past_dims.size() != 4) {
|
||||
|
|
@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
|
|||
}
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1);
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
|
||||
// TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"left_window_size for local attention (like Mistral). Default value is -1 meaning unused.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(-1))
|
||||
.Attr("do_rotary",
|
||||
"Whether to use rotary position embedding. Default value is 0.",
|
||||
AttributeProto::INT,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr("rotary_interleaved",
|
||||
"Rotate using interleaved pattern. Default value is 0 (False).",
|
||||
AttributeProto::INT,
|
||||
OPTIONAL_VALUE)
|
||||
.Input(0,
|
||||
"query",
|
||||
"Query with shape (batch_size, sequence_length, hidden_size)",
|
||||
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape"
|
||||
"(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).",
|
||||
"T")
|
||||
.Input(1,
|
||||
"key",
|
||||
"Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ",
|
||||
"T")
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(2,
|
||||
"value",
|
||||
"Value with shape (batch_size, kv_sequence_length, kv_hidden_size)",
|
||||
"T")
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(3,
|
||||
"past_key",
|
||||
"past state key with support for format BNSH. When past_key uses same tensor as present_key"
|
||||
|
|
@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"total_sequence_length",
|
||||
"Scalar tensor of total sequence length (past + new).",
|
||||
"M")
|
||||
.Input(7,
|
||||
"cos_cache",
|
||||
"2D tensor with shape (max_sequence_length, head_size / 2).",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(8,
|
||||
"sin_cache",
|
||||
"2D tensor with shape (max_sequence_length, head_size / 2).",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Output(0,
|
||||
"output",
|
||||
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
|
||||
|
|
|
|||
693
onnxruntime/test/python/transformers/rotary_flash.py
Normal file
693
onnxruntime/test/python/transformers/rotary_flash.py
Normal file
|
|
@ -0,0 +1,693 @@
|
|||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange, repeat
|
||||
|
||||
##### TRITON KERNEL FOR ROTARY #####
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"block_m": 2}),
|
||||
# triton.Config({"block_m": 4}),
|
||||
# triton.Config({"block_m": 8}),
|
||||
# triton.Config({"block_m": 16}),
|
||||
# ],
|
||||
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
|
||||
# )
|
||||
@triton.jit
|
||||
def rotary_kernel(
|
||||
out_, # Pointers to matrices
|
||||
x_,
|
||||
cos_,
|
||||
sin_,
|
||||
CU_SEQLENS,
|
||||
SEQLEN_OFFSETS, # this could be int or a pointer
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
nheads,
|
||||
rotary_dim,
|
||||
seqlen_ro,
|
||||
CACHE_KEY_SEQLEN,
|
||||
# strides
|
||||
stride_out_batch,
|
||||
stride_out_seqlen,
|
||||
stride_out_nheads,
|
||||
stride_out_headdim,
|
||||
stride_x_batch,
|
||||
stride_x_seqlen,
|
||||
stride_x_nheads,
|
||||
stride_x_headdim,
|
||||
# Meta-parameters
|
||||
block_k: tl.constexpr,
|
||||
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
INTERLEAVED: tl.constexpr,
|
||||
CONJUGATE: tl.constexpr,
|
||||
block_m: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_batch = tl.program_id(axis=1)
|
||||
pid_head = tl.program_id(axis=2)
|
||||
rotary_dim_half = rotary_dim // 2
|
||||
|
||||
if not IS_VARLEN:
|
||||
x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads
|
||||
out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads
|
||||
else:
|
||||
start_idx = tl.load(CU_SEQLENS + pid_batch)
|
||||
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
|
||||
x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
|
||||
out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
|
||||
|
||||
if pid_m * block_m >= seqlen:
|
||||
return
|
||||
rm = pid_m * block_m + tl.arange(0, block_m)
|
||||
if not IS_SEQLEN_OFFSETS_TENSOR:
|
||||
rm_cs = rm + SEQLEN_OFFSETS
|
||||
else:
|
||||
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
|
||||
rk = tl.arange(0, block_k)
|
||||
rk_half = tl.arange(0, block_k // 2)
|
||||
|
||||
if not INTERLEAVED:
|
||||
# Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_
|
||||
x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
|
||||
cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
||||
sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
||||
cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(
|
||||
tl.float32
|
||||
)
|
||||
sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)
|
||||
x1 = tl.load(
|
||||
x_ + rotary_dim_half * stride_x_headdim,
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if CONJUGATE:
|
||||
sin = -sin
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
# write back result
|
||||
out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
||||
tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
||||
tl.store(
|
||||
out_ + rotary_dim_half * stride_out_headdim,
|
||||
o1,
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
)
|
||||
else:
|
||||
# We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow.
|
||||
# Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...].
|
||||
# Loading x0 will be fast but x1 will be slow.
|
||||
# Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...].
|
||||
# Then we do the calculation and use tl.where to pick put the right outputs for the even
|
||||
# and for the odd indices.
|
||||
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
|
||||
rk_repeat = tl.arange(0, block_k) // 2
|
||||
x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
|
||||
x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
|
||||
cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
||||
sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
||||
cos = tl.load(
|
||||
cos_,
|
||||
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
sin = tl.load(
|
||||
sin_,
|
||||
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)
|
||||
x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)
|
||||
if CONJUGATE:
|
||||
sin = -sin
|
||||
x0_cos = x0 * cos
|
||||
x1_sin = x1 * sin
|
||||
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
|
||||
out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
||||
tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
||||
|
||||
|
||||
def apply_rotary(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
interleaved=False,
|
||||
inplace=False,
|
||||
conjugate=False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
|
||||
else (total_seqlen, nheads, headdim).
|
||||
cos: (seqlen_ro, rotary_dim / 2)
|
||||
sin: (seqlen_ro, rotary_dim / 2)
|
||||
seqlen_offsets: integer or integer tensor of size (batch,)
|
||||
cu_seqlens: (batch + 1,) or None
|
||||
max_seqlen: int
|
||||
Returns:
|
||||
y: (batch, seqlen, nheads, headdim)
|
||||
"""
|
||||
is_varlen = cu_seqlens is not None
|
||||
if not is_varlen:
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
else:
|
||||
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
|
||||
total_seqlen, nheads, headdim = x.shape
|
||||
batch_p_1 = cu_seqlens.shape[0]
|
||||
batch = batch_p_1 - 1
|
||||
seqlen = max_seqlen
|
||||
seqlen_ro, rotary_dim = cos.shape
|
||||
assert sin.shape == cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
|
||||
assert headdim <= 256, "Only support headdim <= 256"
|
||||
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
|
||||
|
||||
assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
|
||||
assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
|
||||
|
||||
cos, sin = cos.contiguous(), sin.contiguous()
|
||||
if isinstance(seqlen_offsets, torch.Tensor):
|
||||
assert seqlen_offsets.shape == (batch,)
|
||||
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
|
||||
seqlen_offsets = seqlen_offsets.contiguous()
|
||||
else:
|
||||
assert seqlen_offsets + seqlen <= seqlen_ro
|
||||
|
||||
output = torch.empty_like(x) if not inplace else x
|
||||
if rotary_dim < headdim and not inplace:
|
||||
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
||||
|
||||
block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
|
||||
grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa
|
||||
block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
|
||||
|
||||
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
||||
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||
with torch.cuda.device(x.device.index):
|
||||
rotary_kernel[grid](
|
||||
output, # data ptrs
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
seqlen_offsets,
|
||||
seqlen, # shapes
|
||||
nheads,
|
||||
rotary_dim,
|
||||
seqlen_ro,
|
||||
seqlen // 128, # key for triton cache (limit number of compilations)
|
||||
output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
|
||||
output.stride(-3), # seqlen_stride or total_seqlen_stride
|
||||
output.stride(-2), # nheads_stride
|
||||
output.stride(-1), # headdim_stride
|
||||
x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
|
||||
x.stride(-3), # seqlen stride or total_seqlen_stride
|
||||
x.stride(-2), # nheads stride
|
||||
x.stride(-1), # headdim stride
|
||||
block_k,
|
||||
isinstance(seqlen_offsets, torch.Tensor),
|
||||
is_varlen,
|
||||
interleaved,
|
||||
conjugate,
|
||||
block_m,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
##### ROTARY API #####
|
||||
|
||||
|
||||
def rotate_half(x, interleaved=False):
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
||||
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
||||
return torch.cat(
|
||||
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
class ApplyRotaryEmb(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
interleaved=False,
|
||||
inplace=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
):
|
||||
out = apply_rotary(
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
seqlen_offsets=seqlen_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
interleaved=interleaved,
|
||||
inplace=inplace,
|
||||
)
|
||||
if isinstance(seqlen_offsets, int):
|
||||
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
|
||||
ctx.seqlen_offsets = seqlen_offsets
|
||||
else:
|
||||
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
||||
ctx.seqlen_offsets = None
|
||||
ctx.interleaved = interleaved
|
||||
ctx.inplace = inplace
|
||||
ctx.max_seqlen = max_seqlen
|
||||
return out if not inplace else x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
seqlen_offsets = ctx.seqlen_offsets
|
||||
if seqlen_offsets is None:
|
||||
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
||||
else:
|
||||
cos, sin, cu_seqlens = ctx.saved_tensors
|
||||
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
|
||||
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
|
||||
if not ctx.interleaved and not ctx.inplace:
|
||||
do = do.clone()
|
||||
dx = apply_rotary(
|
||||
do,
|
||||
cos,
|
||||
sin,
|
||||
seqlen_offsets=seqlen_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=ctx.max_seqlen,
|
||||
interleaved=ctx.interleaved,
|
||||
inplace=ctx.inplace,
|
||||
conjugate=True,
|
||||
)
|
||||
return dx, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
interleaved=False,
|
||||
inplace=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
||||
else (total_seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
||||
of 1st half and 2nd half (GPT-NeoX style).
|
||||
inplace: if True, apply rotary embedding in-place.
|
||||
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
||||
Most commonly used in inference when we have KV cache.
|
||||
cu_seqlens: (batch + 1,) or None
|
||||
max_seqlen: int
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
||||
else (total_seqlen, nheads, headdim)
|
||||
rotary_dim must be <= headdim
|
||||
Apply rotary embedding to the first rotary_dim of x.
|
||||
"""
|
||||
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
apply_rotary_emb_func = apply_rotary_emb
|
||||
|
||||
|
||||
class ApplyRotaryEmbQKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
qkv,
|
||||
cos,
|
||||
sin,
|
||||
cos_k=None,
|
||||
sin_k=None,
|
||||
interleaved=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
):
|
||||
batch, seqlen, three, nheads, headdim = qkv.shape
|
||||
assert three == 3
|
||||
if cos_k is None and sin_k is None and qkv.is_contiguous():
|
||||
# Call 1 kernel instead of 2 kernels
|
||||
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
|
||||
# dimensions, we get the same tensor
|
||||
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
|
||||
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
|
||||
apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
|
||||
else:
|
||||
cos_k = cos if cos_k is None else cos_k
|
||||
sin_k = sin if sin_k is None else sin_k
|
||||
q, k = qkv[:, :, 0], qkv[:, :, 1]
|
||||
apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
|
||||
apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
|
||||
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
||||
if isinstance(seqlen_offsets, int):
|
||||
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
||||
ctx.seqlen_offsets = seqlen_offsets
|
||||
else:
|
||||
ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
|
||||
ctx.seqlen_offsets = None
|
||||
ctx.interleaved = interleaved
|
||||
return qkv
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dqkv):
|
||||
seqlen_offsets = ctx.seqlen_offsets
|
||||
if seqlen_offsets is None:
|
||||
cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
|
||||
else:
|
||||
cos, sin, cos_k, sin_k = ctx.saved_tensors
|
||||
if cos_k is None and sin_k is None and dqkv.is_contiguous():
|
||||
# Call 1 kernel instead of 2 kernels
|
||||
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
|
||||
# dimensions, we get the same tensor
|
||||
dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
|
||||
apply_rotary(
|
||||
dqk,
|
||||
cos,
|
||||
sin,
|
||||
seqlen_offsets=seqlen_offsets,
|
||||
interleaved=ctx.interleaved,
|
||||
inplace=True,
|
||||
conjugate=True,
|
||||
)
|
||||
else:
|
||||
cos_k = cos if cos_k is None else cos_k
|
||||
sin_k = sin if sin_k is None else sin_k
|
||||
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
|
||||
apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True)
|
||||
apply_rotary(
|
||||
dk,
|
||||
cos_k,
|
||||
sin_k,
|
||||
seqlen_offsets,
|
||||
interleaved=ctx.interleaved,
|
||||
inplace=True,
|
||||
conjugate=True,
|
||||
)
|
||||
return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
def apply_rotary_emb_qkv_(
|
||||
qkv,
|
||||
cos,
|
||||
sin,
|
||||
cos_k=None,
|
||||
sin_k=None,
|
||||
interleaved=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2)
|
||||
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
|
||||
1st half and 2nd half (GPT-NeoX style).
|
||||
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
|
||||
Most commonly used in inference when we have KV cache.
|
||||
Return:
|
||||
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
rotary_dim must be <= headdim
|
||||
Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
|
||||
"""
|
||||
return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
|
||||
|
||||
|
||||
class ApplyRotaryEmbKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
|
||||
batch, seqlen, two, nheads, headdim = kv.shape
|
||||
assert two == 2
|
||||
k = kv[:, :, 0]
|
||||
apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
|
||||
if isinstance(seqlen_offsets, int):
|
||||
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
|
||||
ctx.seqlen_offsets = seqlen_offsets
|
||||
else:
|
||||
ctx.save_for_backward(cos, sin, seqlen_offsets)
|
||||
ctx.seqlen_offsets = None
|
||||
ctx.interleaved = interleaved
|
||||
return kv
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dkv):
|
||||
seqlen_offsets = ctx.seqlen_offsets
|
||||
if seqlen_offsets is None:
|
||||
cos, sin, seqlen_offsets = ctx.saved_tensors
|
||||
else:
|
||||
cos, sin = ctx.saved_tensors
|
||||
apply_rotary(
|
||||
dkv[:, :, 0],
|
||||
cos,
|
||||
sin,
|
||||
seqlen_offsets=seqlen_offsets,
|
||||
interleaved=ctx.interleaved,
|
||||
inplace=True,
|
||||
conjugate=True,
|
||||
)
|
||||
return dkv, None, None, None, None
|
||||
|
||||
|
||||
apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply
|
||||
|
||||
|
||||
def apply_rotary_emb_kv_(
|
||||
kv,
|
||||
cos,
|
||||
sin,
|
||||
interleaved=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
kv: (batch_size, seqlen, 2, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2)
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
|
||||
1st half and 2nd half (GPT-NeoX style).
|
||||
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
|
||||
Most commonly used in inference when we have KV cache.
|
||||
Return:
|
||||
kv: (batch_size, seqlen, 2, nheads, headdim)
|
||||
rotary_dim must be <= headdim
|
||||
Apply rotary embedding *inplace* to the first rotary_dim of K.
|
||||
"""
|
||||
return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
"""
|
||||
The rotary position embeddings from RoFormer_ (Su et. al).
|
||||
A crucial insight from the method is that the query and keys are
|
||||
transformed by rotation matrices which depend on the relative positions.
|
||||
|
||||
Other implementations are available in the Rotary Transformer repo_ and in
|
||||
GPT-NeoX_, GPT-NeoX was an inspiration
|
||||
|
||||
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
||||
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
||||
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
||||
|
||||
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
||||
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
||||
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base=10000.0,
|
||||
interleaved=False,
|
||||
scale_base=None,
|
||||
pos_idx_in_fp32=True,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
||||
of 1st half and 2nd half (GPT-NeoX style).
|
||||
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
||||
otherwise they might be in lower precision.
|
||||
This option was added because previously (before 2023-07-02), when we construct
|
||||
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
||||
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
||||
self.inv_freq would be bf16, and the position indices are also in bf16.
|
||||
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
||||
embeddings for some positions will coincide.
|
||||
To maintain compatibility with models previously trained in pure bf16,
|
||||
we add this option.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base = float(base)
|
||||
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = self._compute_inv_freq(device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.interleaved = interleaved
|
||||
self.scale_base = scale_base
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||
if scale_base is not None
|
||||
else None
|
||||
)
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _compute_inv_freq(self, device=None):
|
||||
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
||||
|
||||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# if we're on a new device (possibly due to tracing for instance),
|
||||
# or if we're switching from inference mode to training
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
||||
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
||||
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
||||
if self.pos_idx_in_fp32:
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
||||
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
||||
# cos & sin output to change significantly.
|
||||
# We want to recompute self.inv_freq if it was not loaded in fp32
|
||||
if self.inv_freq.dtype != torch.float32:
|
||||
inv_freq = self._compute_inv_freq(device=device)
|
||||
else:
|
||||
inv_freq = self.inv_freq
|
||||
else:
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
inv_freq = self.inv_freq
|
||||
# Don't do einsum, it converts fp32 to fp16 under AMP
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
else:
|
||||
power = (
|
||||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
qkv: torch.Tensor,
|
||||
kv: Optional[torch.Tensor] = None,
|
||||
seqlen_offset: Union[int, torch.Tensor] = 0,
|
||||
max_seqlen: Optional[int] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
|
||||
else it's just q of shape (batch, seqlen, nheads, headdim)
|
||||
kv: (batch, seqlen, 2, nheads, headdim)
|
||||
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
||||
Most commonly used in inference when we have KV cache.
|
||||
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
|
||||
should pass in max_seqlen, which will update the cos / sin cache up to that length.
|
||||
Apply rotary embedding *inplace* to qkv and / or kv.
|
||||
"""
|
||||
seqlen = qkv.shape[1]
|
||||
if max_seqlen is not None:
|
||||
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
||||
elif isinstance(seqlen_offset, int):
|
||||
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
||||
if kv is None:
|
||||
if self.scale is None:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv,
|
||||
self._cos_cached,
|
||||
self._sin_cached,
|
||||
interleaved=self.interleaved,
|
||||
seqlen_offsets=seqlen_offset,
|
||||
)
|
||||
else:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv,
|
||||
self._cos_cached,
|
||||
self._sin_cached,
|
||||
self._cos_k_cached,
|
||||
self._sin_k_cached,
|
||||
interleaved=self.interleaved,
|
||||
seqlen_offsets=seqlen_offset,
|
||||
)
|
||||
else:
|
||||
q = qkv
|
||||
q = apply_rotary_emb_func(
|
||||
q,
|
||||
self._cos_cached,
|
||||
self._sin_cached,
|
||||
interleaved=self.interleaved,
|
||||
inplace=True,
|
||||
seqlen_offsets=seqlen_offset,
|
||||
)
|
||||
if self.scale is None:
|
||||
kv = apply_rotary_emb_kv_(
|
||||
kv,
|
||||
self._cos_cached,
|
||||
self._sin_cached,
|
||||
interleaved=self.interleaved,
|
||||
seqlen_offsets=seqlen_offset,
|
||||
)
|
||||
else:
|
||||
kv = apply_rotary_emb_kv_(
|
||||
kv,
|
||||
self._cos_k_cached,
|
||||
self._sin_k_cached,
|
||||
interleaved=self.interleaved,
|
||||
seqlen_offsets=seqlen_offset,
|
||||
)
|
||||
return q, kv
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
|
|||
numpy_init_version = numpy.__version__
|
||||
pb_init_version = google.protobuf.__version__
|
||||
run_subprocess(
|
||||
[sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR
|
||||
[sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"],
|
||||
cwd=SCRIPT_DIR,
|
||||
)
|
||||
run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd)
|
||||
# Restore initial numpy/protobuf version in case other tests use it
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ packaging
|
|||
protobuf==3.20.2
|
||||
numpy==1.24.0 ; python_version < '3.12'
|
||||
numpy==1.26.0 ; python_version >= '3.12'
|
||||
torch
|
||||
coloredlogs==15.0
|
||||
transformers==4.36.0
|
||||
psutil
|
||||
einops
|
||||
einops
|
||||
Loading…
Reference in a new issue