mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add Gemma Rotary Embedding (#20267)
### Description Add GemmaRotaryEmbedding kernel which includes sin and cos in GemmaRotaryEmbedding forward and apply_rotary_pos_emb. See gemma_rotary_emb_impl.cu for subgraph details ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
7354f3cdd8
commit
c11941289b
10 changed files with 470 additions and 0 deletions
|
|
@ -41,6 +41,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.Gelu">com.microsoft.Gelu</a>
|
||||
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
|
||||
* <a href="#com.microsoft.GemmFloat8">com.microsoft.GemmFloat8</a>
|
||||
* <a href="#com.microsoft.GemmaRotaryEmbedding">com.microsoft.GemmaRotaryEmbedding</a>
|
||||
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
|
||||
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
|
||||
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
|
||||
|
|
@ -2210,6 +2211,69 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.GemmaRotaryEmbedding"></a><a name="com.microsoft.gemmarotaryembedding">**com.microsoft.GemmaRotaryEmbedding**</a>
|
||||
|
||||
GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py.
|
||||
|
||||
Here's onnxscript that was tested
|
||||
|
||||
from onnxscript import FLOAT, FLOAT16, script
|
||||
from onnxscript import opset18 as op
|
||||
|
||||
@script()
|
||||
def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]):
|
||||
sin_val = op.Sin(emb)
|
||||
casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported.
|
||||
cos_val = op.Cos(emb)
|
||||
casted_cos = op.Cast(cos_val, to=10)
|
||||
unsqueezed_sin = op.Unsqueeze(casted_sin, [1])
|
||||
unsqueezed_cos = op.Unsqueeze(casted_cos, [1])
|
||||
q_embed = (q * casted_cos) + (q_rot * casted_sin)
|
||||
k_embed = (k * casted_cos) + (k_rot * casted_sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
onnx_model = gemma_rotary_embedding.to_model_proto()
|
||||
|
||||
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>emb</tt> : U</dt>
|
||||
<dd>embedding - 3D tensor with shape (batch_size, seq_len, dim)</dd>
|
||||
<dt><tt>q</tt> : T</dt>
|
||||
<dd>q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
|
||||
<dt><tt>q_rot</tt> : T</dt>
|
||||
<dd>half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
|
||||
<dt><tt>k</tt> : T</dt>
|
||||
<dd>k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
|
||||
<dt><tt>k_rot</tt> : T</dt>
|
||||
<dd>half rotated k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>output1</tt> : T</dt>
|
||||
<dd>4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
|
||||
<dt><tt>output2</tt> : T</dt>
|
||||
<dd>4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float16)</dt>
|
||||
<dd>Constrain input and output types to float16 tensors.</dd>
|
||||
<dt><tt>U</tt> : tensor(float)</dt>
|
||||
<dd>Constrain input 0 type to float tensors</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.GreedySearch"></a><a name="com.microsoft.greedysearch">**com.microsoft.GreedySearch**</a>
|
||||
|
||||
Greedy Search for text generation.
|
||||
|
|
|
|||
|
|
@ -868,6 +868,7 @@ Do not modify directly.*
|
|||
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|
||||
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
|
||||
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|
||||
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
75
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc
Normal file
75
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/gemma_rotary_emb.h"
#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T, U)                                \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                   \
      GemmaRotaryEmbedding,                                        \
      kMSDomain,                                                   \
      1,                                                           \
      T,                                                           \
      kCudaExecutionProvider,                                      \
      (*KernelDefBuilder::Create())                                \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())   \
          .TypeConstraint("U", DataTypeImpl::GetTensorType<U>()),  \
      GemmaRotaryEmbedding<T, U>);

// Only fp16 q/k with fp32 emb is supported (matches the schema's T/U constraints).
REGISTER_KERNEL_TYPED(MLFloat16, float)

template <typename T, typename U>
GemmaRotaryEmbedding<T, U>::GemmaRotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
}

// Validates input shapes and dispatches the fused sin/cos + rotary-mix CUDA kernel.
// Inputs:
//   emb   : (batch_size, seq_len, dim) fp32 rotary angles
//   q     : (batch_size, num_heads, seq_len, dim) fp16 query states
//   q_rot : half-rotated query states, same shape as q
//   k     : fp16 key states, same shape as q
//   k_rot : half-rotated key states, same shape as q
// Outputs: output1 = q * cos(emb) + q_rot * sin(emb), output2 likewise for k.
template <typename T, typename U>
Status GemmaRotaryEmbedding<T, U>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* emb = context->Input<Tensor>(0);
  const Tensor* q = context->Input<Tensor>(1);
  const Tensor* q_rot = context->Input<Tensor>(2);
  const Tensor* k = context->Input<Tensor>(3);
  const Tensor* k_rot = context->Input<Tensor>(4);

  const auto& emb_dims = emb->Shape().GetDims();
  const auto& q_dims = q->Shape().GetDims();
  int batch_size = static_cast<int>(q_dims[0]);
  int num_heads = static_cast<int>(q_dims[1]);
  int seq_len = static_cast<int>(q_dims[2]);
  int dim = static_cast<int>(q_dims[3]);

  // q_dims should be [batch_size, num_heads, seq_len, dim]
  // emb_dims should be [batch_size, seq_len, dim]
  ORT_ENFORCE(emb_dims.size() == 3, "emb_dims should be 3D");
  // Fixed: the message previously said "emb_dims" while checking q_dims.
  ORT_ENFORCE(q_dims.size() == 4, "q_dims should be 4D");
  ORT_ENFORCE(emb_dims[0] == batch_size, "emb_dims[0] should match q_dims[0]");
  ORT_ENFORCE(emb_dims[1] == seq_len, "emb_dims[1] should match q_dims[2]");
  ORT_ENFORCE(emb_dims[2] == dim, "emb_dims[2] should match q_dims[3]");

  // Both outputs have the same shape as q/k.
  Tensor* output1 = context->Output(0, q_dims);
  Tensor* output2 = context->Output(1, q_dims);

  typedef typename ToCudaType<T>::MappedType CudaT;
  typedef typename ToCudaType<U>::MappedType CudaU;
  return LaunchGemmaRotaryEmbeddingKernel<CudaT, CudaU>(
      Stream(context),
      reinterpret_cast<CudaT*>(output1->template MutableData<T>()),
      reinterpret_cast<CudaT*>(output2->template MutableData<T>()),
      reinterpret_cast<const CudaU*>(emb->template Data<U>()),
      reinterpret_cast<const CudaT*>(q->template Data<T>()),
      reinterpret_cast<const CudaT*>(q_rot->template Data<T>()),
      reinterpret_cast<const CudaT*>(k->template Data<T>()),
      reinterpret_cast<const CudaT*>(k_rot->template Data<T>()),
      batch_size,
      num_heads,
      seq_len,
      dim);
}

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
|
||||
24
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h
Normal file
24
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

using onnxruntime::cuda::CudaKernel;
using onnxruntime::cuda::ToCudaType;

// CUDA kernel for the com.microsoft.GemmaRotaryEmbedding contrib op.
// T is the element type of q/q_rot/k/k_rot and the outputs (fp16),
// U is the element type of the emb (rotary angle) input (fp32).
template <typename T, typename U>
class GemmaRotaryEmbedding final : public CudaKernel {
 public:
  // `explicit`: a single-argument constructor should not allow implicit conversion.
  explicit GemmaRotaryEmbedding(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
|
||||
104
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu
Normal file
104
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
/*
 Copyright (c) Microsoft Corporation.
 Licensed under the MIT License.
*/
/*
 Kernel implementation for Gemma rotary embeddings.
 This implements the below subgraph:
            (emb)
            /   \
          Sin    Cos
           |      |
          Cast   Cast
           |      |
      Unsqueeze Unsqueeze
    \/    \/        \/    \/
    Mul   Mul      Mul   Mul
      \   /          \   /
       Add            Add
        |              |
    (output1)      (output2)
*/

#include <cuda_fp16.h>
#include <cmath>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {

constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock;

// One thread per q/k element. Each thread reads the shared rotary angle,
// computes sin/cos once, and writes both rotated outputs.
template <typename T, typename U>
__global__ void GemmaRotaryEmb(
    T* output1,
    T* output2,
    const U* emb,
    const T* q,
    const T* q_rot,
    const T* k,
    const T* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim) {
  const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Map the 4D q/k index [i, j, k, l] onto the 3D emb index [i, k, l]:
  // the num_heads axis (j) is broadcast.
  const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim);
  if (qk_idx < batch_size * num_heads * seq_len * dim) {
    T sin_val = static_cast<T>(sin(emb[emb_idx]));
    T cos_val = static_cast<T>(cos(emb[emb_idx]));
    output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val;
    output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val;
  }
}

template <typename T, typename U>
Status LaunchGemmaRotaryEmbeddingKernel(
    cudaStream_t stream,
    T* output1,
    T* output2,
    const U* emb,
    const T* q,
    const T* q_rot,
    const T* k,
    const T* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim) {
  const int total_elements = batch_size * num_heads * seq_len * dim;
  // Integer ceiling division. The previous float-based ceil() loses precision
  // once total_elements exceeds 2^24 and could launch one block too few.
  const int blocksPerGrid = (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock;

  GemmaRotaryEmb<<<blocksPerGrid, kThreadsPerBlock, 0, stream>>>(
      output1, output2,
      emb, q, q_rot, k, k_rot,
      batch_size, num_heads, seq_len, dim);

  return CUDA_CALL(cudaGetLastError());
}

// Explicit instantiation for the only registered type combination (fp16 q/k, fp32 emb).
template Status LaunchGemmaRotaryEmbeddingKernel<half, float>(
    cudaStream_t stream,
    half* output1,
    half* output2,
    const float* emb,
    const half* q,
    const half* q_rot,
    const half* k,
    const half* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
|
||||
29
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h
Normal file
29
onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

// Launches the fused Gemma rotary-embedding kernel on `stream`.
// Writes q * cos(emb) + q_rot * sin(emb) into output1 and
// k * cos(emb) + k_rot * sin(emb) into output2, where emb has shape
// (batch_size, seq_len, dim) and is broadcast across num_heads.
template <typename T, typename U>
Status LaunchGemmaRotaryEmbeddingKernel(
    cudaStream_t stream,
    T* output1,
    T* output2,
    const U* emb,
    const T* q,
    const T* q_rot,
    const T* k,
    const T* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
|
||||
|
|
@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
|
||||
|
|
@ -302,6 +303,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
|
||||
|
|
|
|||
|
|
@ -1215,6 +1215,71 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
}));
|
||||
|
||||
constexpr const char* GemmaRotaryEmbedding_ver1_doc = R"DOC(
|
||||
GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py.
|
||||
|
||||
Here's onnxscript that was tested
|
||||
|
||||
from onnxscript import FLOAT, FLOAT16, script
|
||||
from onnxscript import opset18 as op
|
||||
|
||||
@script()
|
||||
def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]):
|
||||
sin_val = op.Sin(emb)
|
||||
casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported.
|
||||
cos_val = op.Cos(emb)
|
||||
casted_cos = op.Cast(cos_val, to=10)
|
||||
unsqueezed_sin = op.Unsqueeze(casted_sin, [1])
|
||||
unsqueezed_cos = op.Unsqueeze(casted_cos, [1])
|
||||
q_embed = (q * casted_cos) + (q_rot * casted_sin)
|
||||
k_embed = (k * casted_cos) + (k_rot * casted_sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
onnx_model = gemma_rotary_embedding.to_model_proto()
|
||||
|
||||
|
||||
)DOC";
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(
|
||||
GemmaRotaryEmbedding, 1,
|
||||
OpSchema()
|
||||
.SetDoc(GemmaRotaryEmbedding_ver1_doc)
|
||||
.Input(0,
|
||||
"emb",
|
||||
"embeddding - 3D tensor with shape (batch_size, seq_len, dim)",
|
||||
"U")
|
||||
.Input(1,
|
||||
"q",
|
||||
"q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
|
||||
"T")
|
||||
.Input(2,
|
||||
"q_rot",
|
||||
"half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
|
||||
"T")
|
||||
.Input(3,
|
||||
"k",
|
||||
"k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
|
||||
"T")
|
||||
.Input(4,
|
||||
"k_rot",
|
||||
"k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
|
||||
"T")
|
||||
.Output(0,
|
||||
"output1",
|
||||
"4D tensor with shape (batch_size, num_heads, seq_len, dim)",
|
||||
"T")
|
||||
.Output(1,
|
||||
"output2",
|
||||
"4D tensor with shape (batch_size, num_heads, seq_len, dim)",
|
||||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float16 tensors.")
|
||||
.TypeConstraint("U", {"tensor(float)"}, "Constrain input 0 type to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 1, 0);
|
||||
propagateElemTypeFromInputToOutput(ctx, 1, 1);
|
||||
propagateShapeFromInputToOutput(ctx, 1, 0);
|
||||
propagateShapeFromInputToOutput(ctx, 1, 1);
|
||||
}));
|
||||
|
||||
constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC(
|
||||
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
|
||||
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
|
||||
|
|
@ -208,6 +209,7 @@ class OpSet_Microsoft_ver1 {
|
|||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm)>());
|
||||
|
|
|
|||
104
onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc
Normal file
104
onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cassert>
#include <cmath>    // for std::sin / std::cos (previously relied on a transitive include)
#include <cstdint>  // for int64_t
#include "gtest/gtest.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

constexpr auto k_random_data_min = -1.0f;
constexpr auto k_random_data_max = 1.0f;

// Reference implementation mirroring the CUDA kernel:
// output1 = q * cos(emb) + q_rot * sin(emb), output2 likewise for k,
// with emb (batch, seq, dim) broadcast over the num_heads axis of
// mul_dim = (batch, num_heads, seq, dim).
static void calculateExpectedOutput(const std::vector<float>& emb_data,
                                    const std::vector<MLFloat16>& q_data,
                                    const std::vector<MLFloat16>& q_rot_data,
                                    const std::vector<MLFloat16>& k_data,
                                    const std::vector<MLFloat16>& k_rot_data,
                                    const std::vector<int64_t>& mul_dim,
                                    std::vector<MLFloat16>& output1,
                                    std::vector<MLFloat16>& output2) {
  for (int64_t i = 0; i < mul_dim[0]; ++i) {
    for (int64_t j = 0; j < mul_dim[1]; ++j) {
      for (int64_t k = 0; k < mul_dim[2]; ++k) {
        for (int64_t l = 0; l < mul_dim[3]; ++l) {
          // emb index [i, k, l] — the num_heads axis (j) is broadcast.
          int64_t embIdx = i * mul_dim[1] * mul_dim[3] + k * mul_dim[3] + l;
          int64_t mulIdx = i * mul_dim[1] * mul_dim[2] * mul_dim[3] + j * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l;

          MLFloat16 sin_val = static_cast<MLFloat16>(std::sin(emb_data[embIdx]));
          MLFloat16 cos_val = static_cast<MLFloat16>(std::cos(emb_data[embIdx]));
          MLFloat16 q_val = static_cast<MLFloat16>(q_data[mulIdx]);
          MLFloat16 q_rot_val = static_cast<MLFloat16>(q_rot_data[mulIdx]);
          MLFloat16 k_val = static_cast<MLFloat16>(k_data[mulIdx]);
          MLFloat16 k_rot_val = static_cast<MLFloat16>(k_rot_data[mulIdx]);
          output1.push_back(static_cast<MLFloat16>(q_val * cos_val + q_rot_val * sin_val));
          output2.push_back(static_cast<MLFloat16>(k_val * cos_val + k_rot_val * sin_val));
        }
      }
    }
  }
}

// Runs GemmaRotaryEmbedding on the CUDA EP against the reference above.
// The op only has a CUDA (fp16) implementation, so the test is a no-op when
// no capable CUDA environment is available.
static void RunTest() {
  std::string op_type = "GemmaRotaryEmbedding";
  std::vector<int64_t> emb_dim = {1, 2, 2};
  std::vector<int64_t> mul_dim = {1, 3, 2, 2};
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;

  // fp16 requires compute capability >= 5.3.
  int min_cuda_architecture = 530;
  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);

  if (enable_cuda) {
    execution_providers.push_back(DefaultCudaExecutionProvider());
  }

  if (execution_providers.size() == 0) {
    // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline)
    return;
  }

  OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);

  // create rand inputs
  RandomValueGenerator random{};
  const std::vector<float> emb_data = random.Uniform<float>(emb_dim, k_random_data_min, k_random_data_max);
  const std::vector<MLFloat16> q = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
  const std::vector<MLFloat16> q_rot = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
  const std::vector<MLFloat16> k = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
  const std::vector<MLFloat16> k_rot = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);

  std::vector<MLFloat16> output1;
  std::vector<MLFloat16> output2;

  calculateExpectedOutput(emb_data, q, q_rot, k, k_rot, mul_dim, output1, output2);

  test.AddInput<float>("emb", emb_dim, emb_data);
  test.AddInput<MLFloat16>("q_data", mul_dim, q);
  test.AddInput<MLFloat16>("q_rot_data", mul_dim, q_rot);
  test.AddInput<MLFloat16>("k_data", mul_dim, k);
  test.AddInput<MLFloat16>("k_rot_data", mul_dim, k_rot);
  test.AddOutput<MLFloat16>("output1", mul_dim, output1);
  test.AddOutput<MLFloat16>("output2", mul_dim, output2);

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(GemmaRotaryEmbeddingTest, GemmaRotaryEmbedding_Small) {
  RunTest();
}

}  // namespace test
}  // namespace onnxruntime
|
||||
Loading…
Reference in a new issue