pass rotary embedding to attention op (#18846)
### Description
Add a `rotary_embedding_dim` attribute to the com.microsoft Attention operator and pass it through to the CUDA rotary path in AddBiasTranspose. The value is limited to 32, 64 or 128 and defaults to `head_size` when not set.

### Motivation and Context
The rotary path previously assumed the rotary dimension equals the head size (and restricted the head size to 64 or 128), so models whose rotary dimension is smaller than the head size, such as some GPT-NeoX variants, could not use the fused Attention op. The parity test now covers `rotary_ndims` of 32 and 64 with head sizes 64 and 80.
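For reference, a minimal sketch of how a graph author might request the new behavior. Attribute names follow the operator spec and `AttentionBase` changes below; the tensor names, shapes, and attribute values here are made up for illustration only.

```python
# Illustrative only: build a com.microsoft Attention node that enables rotary embedding
# with a rotary dimension smaller than head_size. Input/output names are placeholders.
from onnx import helper

attention_node = helper.make_node(
    "Attention",
    inputs=["input", "qkv_weight", "qkv_bias"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=12,
    unidirectional=1,
    do_rotary=1,              # turn on rotary position embedding
    rotary_embedding_dim=32,  # 32, 64 or 128; defaults to head_size when omitted
)
```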
parent df740d7d15
commit 1c2dca95d8
9 changed files with 43 additions and 23 deletions
@@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
 <dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
 <dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
+<dt><tt>rotary_embedding_dim</tt> : int</dt>
+<dd>Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size</dd>
 <dt><tt>scale</tt> : float</dt>
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
 <dt><tt>unidirectional</tt> : int</dt>
@@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     output_parameters->is_unidirectional = is_unidirectional_;
     output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
     output_parameters->do_rotary = do_rotary_;
+    output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
     output_parameters->mask_filter_value = mask_filter_value_;
     output_parameters->scale = scale_;
     output_parameters->mask_type = mask_type;
@@ -38,6 +38,7 @@ class AttentionBase {
     is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
     do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+    rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
     mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
     scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
@@ -72,6 +73,7 @@ class AttentionBase {
   bool require_same_hidden_size_;   // whether the implementation supports different hidden sizes of Q/K/V.
   bool past_present_share_buffer_;  // whether or not the past (if used) and present tensor share the same buffer
   bool do_rotary_;                  // whether or not to use rotary embeddings
+  int rotary_embedding_;            // rotary embedding dimension
   float mask_filter_value_;         // the value to be used for filtered out positions
   float scale_;                     // the scale to be used for softmax
 };
@@ -56,6 +56,7 @@ struct AttentionParameters {
   int v_head_size;  // hidden size per head of V
   int num_heads;
   int num_splits;
+  int rotary_embedding;
   bool is_unidirectional;
   bool past_present_share_buffer;
   bool do_rotary;
@@ -640,7 +640,7 @@ void InvokeAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
     const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count,
-    bool do_rotary = false, int past_sequence_length = 0) {
+    bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) {
   assert(num_heads <= max_threads_per_block);
 
   if (do_rotary) {
@@ -650,20 +650,20 @@ void InvokeAddBiasTranspose(
     if (format != 1 && format != 2 && format != 3) {
       ORT_THROW("format must be 1, 2 or 3 for rotary attention");
     }
-    if (qk_head_size != 64 && qk_head_size != 128) {
-      ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
+    if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
+      ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention");
     }
     if (v_head_size != -1 && qk_head_size != v_head_size) {
       ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention");
     }
 
     const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length;
-    size_t smem_size = 2 * qk_head_size * sizeof(T);
+    size_t smem_size = 2 * rotary_embedding * sizeof(T);
 
     const dim3 grid(sequence_length, num_heads, batch_size);
     const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1);
     AddBiasTransposeQKV<T><<<grid, block, smem_size, stream>>>(total_matrix_count, input, biases, output,
-                                                               qkv_add_bias, qk_head_size, qk_head_size,
+                                                               qkv_add_bias, rotary_embedding, qk_head_size,
                                                                step, format);
 #else
     ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__);
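To see what the new `rotary_embedding` argument changes in the launch configuration above, here is a small worked example of the same arithmetic (my own illustration, with values chosen to satisfy the 32/64/128 constraint; it is not code from this change):

```python
# Plugging head_size = 64, rotary dim = 32, half precision into the formulas above.
qk_head_size = 64
rotary_embedding = 32
sizeof_half = 2  # bytes

# Shared memory is now sized by the rotary dimension rather than the full head size.
smem_size = 2 * rotary_embedding * sizeof_half  # 128 bytes (previously 2 * 64 * 2 = 256)

# One thread per pair of rotated channels, rounded up to a multiple of a 32-thread warp.
block_x = (qk_head_size // 2 + 31) // 32 * 32   # 32 threads

print(smem_size, block_x)
```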
@@ -727,7 +727,7 @@ void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
     const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size,
-    half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) {
+    half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) {
   total_matrix_count = std::max(num_matrices, total_matrix_count);
   if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
     const int H = qk_head_size / 4;
@@ -753,7 +753,7 @@ void LaunchAddBiasTranspose(
     InvokeAddBiasTranspose<half>(
         stream, num_matrices, format, max_threads_per_block,
         batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
-        qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+        qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length);
   }
 }
 
@@ -763,7 +763,7 @@ void LaunchAddBiasTranspose(
     const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const float* input, const float* biases, float* output, bool /*enable_half4*/,
     const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary,
-    int past_sequence_length) {
+    int rotary_embedding, int past_sequence_length) {
   total_matrix_count = std::max(num_matrices, total_matrix_count);
   if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
     const int H = qk_head_size / 4;
@@ -789,7 +789,8 @@ void LaunchAddBiasTranspose(
     InvokeAddBiasTranspose<float>(
         stream, num_matrices, format, max_threads_per_block,
         batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
-        qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+        qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding,
+        past_sequence_length);
   }
 }
 
@@ -33,7 +33,7 @@ void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
     const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
-    int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0);
+    int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);
 
 // Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
 // For self attention:
@@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
     LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
                            batch_size, sequence_length, num_heads, qk_head_size,
                            data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
-                           3, parameters.do_rotary, parameters.past_sequence_length);
+                           3, parameters.do_rotary, parameters.rotary_embedding,
+                           parameters.past_sequence_length);
   }
   return Status::OK();
 }
@@ -333,6 +333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
           "Whether to use rotary position embedding. Default value is 0.",
           AttributeProto::INT,
           OPTIONAL_VALUE)
+    .Attr("rotary_embedding_dim",
+          "Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size",
+          AttributeProto::INT,
+          OPTIONAL_VALUE)
     .Attr("mask_filter_value",
           "The value to be filled in the attention mask. Default value is -10000.0f",
           AttributeProto::FLOAT,
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
     qkv_weight,
     qkv_bias,
     num_heads,
+    rotary_embedding,
 ):
     nodes = [
         helper.make_node(
@@ -43,6 +44,7 @@ def create_neox_attention_graph(
             num_heads=num_heads,
             unidirectional=1,
             do_rotary=1,
+            rotary_embedding=rotary_embedding,
             domain="com.microsoft",
         ),
     ]
@@ -174,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
 
 
 class GPTNeoXAttention(nn.Module):
-    def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
+    def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
         super().__init__()
         self.do_rotary = True
         self.num_attention_heads = num_head
         self.hidden_size = hidden_size
         self.head_size = self.hidden_size // self.num_attention_heads
-        self.rotary_ndims = int(self.head_size)
+        self.rotary_ndims = rotary_ndims
         max_positions = 2048
         self.register_buffer(
             "bias",
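The `rotary_ndims` parameter added above controls a partial rotary embedding. As a rough sketch of the idea (GPT-NeoX-style rotation written from scratch here, not copied from the test file), only the first `rotary_ndims` channels of each head are rotated and the remaining channels pass through unchanged:

```python
import torch


def rotate_half(x):
    # Swap the two halves of the rotary slice and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rotary(q, cos, sin, rotary_ndims):
    # q: (batch, num_heads, seq_len, head_size); cos/sin broadcast over the first rotary_ndims channels.
    q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1)
```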
@@ -197,6 +199,7 @@ class GPTNeoXAttention(nn.Module):
         # self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))
 
         if past_seq_len > 0:
+            assert self.rotary_ndims == self.head_size
             self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
                 batch_size,
                 seq_len,
@@ -220,6 +223,7 @@ class GPTNeoXAttention(nn.Module):
                 .transpose(0, 1),
                 self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
                 self.num_attention_heads,
+                self.rotary_ndims,
             )
 
     @classmethod
@@ -422,17 +426,21 @@ class TestGPTNeoXAttention(unittest.TestCase):
         for batch_size in [1, 2, 4, 8]:
             for seq_len in [32, 128, 512, 1024, 2048]:
                 for num_head in [12]:
-                    for hidden_size in [768]:
-                        attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)
+                    for rotary_ndims in [32, 64]:
+                        for hidden_size in [768, 960]:
+                            attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)
 
-                        hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
-                            torch.float32
-                        )
+                            hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
+                                torch.float32
+                            )
 
-                        torch_output = attn.torch_forward(hidden_states)
-                        ort_output = attn.onnx_forward(hidden_states)
-                        if ort_output is not None:
-                            assert torch.allclose(torch_output, ort_output, atol=1e-4)
+                            torch_output = attn.torch_forward(hidden_states)
+                            ort_output = attn.onnx_forward(hidden_states)
+                            if ort_output is not None:
+                                assert torch.allclose(torch_output, ort_output, atol=1e-3)
+                            print(
+                                f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
+                            )
 
     def test_gpt_neox_decoder_masked_self_attention(self):
         for batch_size in [1, 2, 4, 8]:
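The widened test matrix above is what exercises a rotary dimension smaller than the head size; the quick check below (my own note, not part of the test) shows the head sizes it produces:

```python
num_head = 12
for hidden_size in [768, 960]:
    head_size = hidden_size // num_head  # 64 and 80
    for rotary_ndims in [32, 64]:
        # rotary_ndims is no longer forced to equal head_size
        print(hidden_size, head_size, rotary_ndims)
```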
@@ -466,7 +474,7 @@ class TestGPTNeoXAttention(unittest.TestCase):
                             hidden_states, attention_mask=attention_mask, layer_past=layer_past
                         )
                         if ort_output is not None:
-                            assert torch.allclose(torch_output, ort_output, atol=1e-4)
+                            assert torch.allclose(torch_output, ort_output, atol=1e-3)
 
 
 if __name__ == "__main__":