Add Gemma Rotary Embedding (#20267)

### Description
Add GemmaRotaryEmbedding kernel which includes sin and cos in
GemmaRotaryEmbedding forward and apply_rotary_pos_emb. See
gemma_rotary_emb_impl.cu for subgraph details

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
jingyanwangms 2024-04-16 15:31:56 -07:00 committed by GitHub
parent 7354f3cdd8
commit c11941289b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 470 additions and 0 deletions

View file

@ -41,6 +41,7 @@ Do not modify directly.*
* <a href="#com.microsoft.Gelu">com.microsoft.Gelu</a>
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
* <a href="#com.microsoft.GemmFloat8">com.microsoft.GemmFloat8</a>
* <a href="#com.microsoft.GemmaRotaryEmbedding">com.microsoft.GemmaRotaryEmbedding</a>
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
@ -2210,6 +2211,69 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.GemmaRotaryEmbedding"></a><a name="com.microsoft.gemmarotaryembedding">**com.microsoft.GemmaRotaryEmbedding**</a>
GemmaRotaryEmbedding implements the part of rotary positional embeddings (RoPE) shown below, from modeling_gemma.py.
Here's onnxscript that was tested
from onnxscript import FLOAT, FLOAT16, script
from onnxscript import opset18 as op
@script()
def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]):
sin_val = op.Sin(emb)
casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported.
cos_val = op.Cos(emb)
casted_cos = op.Cast(cos_val, to=10)
unsqueezed_sin = op.Unsqueeze(casted_sin, [1])
unsqueezed_cos = op.Unsqueeze(casted_cos, [1])
q_embed = (q * casted_cos) + (q_rot * casted_sin)
k_embed = (k * casted_cos) + (k_rot * casted_sin)
return q_embed, k_embed
onnx_model = gemma_rotary_embedding.to_model_proto()
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Inputs
<dl>
<dt><tt>emb</tt> : U</dt>
<dd>embedding - 3D tensor with shape (batch_size, seq_len, dim)</dd>
<dt><tt>q</tt> : T</dt>
<dd>q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
<dt><tt>q_rot</tt> : T</dt>
<dd>half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
<dt><tt>k</tt> : T</dt>
<dd>k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
<dt><tt>k_rot</tt> : T</dt>
<dd>half rotated k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
</dl>
#### Outputs
<dl>
<dt><tt>output1</tt> : T</dt>
<dd>4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
<dt><tt>output2</tt> : T</dt>
<dd>4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output types to float16 tensors.</dd>
<dt><tt>U</tt> : tensor(float)</dt>
<dd>Constrain input 0 type to float tensors</dd>
</dl>
### <a name="com.microsoft.GreedySearch"></a><a name="com.microsoft.greedysearch">**com.microsoft.GreedySearch**</a>
Greedy Search for text generation.

View file

@ -868,6 +868,7 @@ Do not modify directly.*
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|

View file

@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/gemma_rotary_emb.h"
#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Registers the GemmaRotaryEmbedding CUDA kernel in the com.microsoft
// (kMSDomain) domain, opset 1. T constrains the q/k/output element type,
// U constrains the embedding element type.
#define REGISTER_KERNEL_TYPED(T, U)                                \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                   \
      GemmaRotaryEmbedding,                                        \
      kMSDomain,                                                   \
      1,                                                           \
      T,                                                           \
      kCudaExecutionProvider,                                      \
      (*KernelDefBuilder::Create())                                \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())   \
          .TypeConstraint("U", DataTypeImpl::GetTensorType<U>()),  \
      GemmaRotaryEmbedding<T, U>);

// Only the (MLFloat16, float) combination is registered today, matching the
// schema's T = tensor(float16), U = tensor(float) constraints.
REGISTER_KERNEL_TYPED(MLFloat16, float)

// The op has no attributes, so the constructor only forwards to CudaKernel.
template <typename T, typename U>
GemmaRotaryEmbedding<T, U>::GemmaRotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
}
// Computes fused Gemma rotary positional embeddings on CUDA:
//   output1 = q * cos(emb) + q_rot * sin(emb)
//   output2 = k * cos(emb) + k_rot * sin(emb)
// emb is (batch_size, seq_len, dim) and is broadcast over the num_heads axis;
// q/q_rot/k/k_rot and both outputs are (batch_size, num_heads, seq_len, dim).
// Returns the CUDA launch status.
template <typename T, typename U>
Status GemmaRotaryEmbedding<T, U>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* emb = context->Input<Tensor>(0);
  const Tensor* q = context->Input<Tensor>(1);
  const Tensor* q_rot = context->Input<Tensor>(2);
  const Tensor* k = context->Input<Tensor>(3);
  const Tensor* k_rot = context->Input<Tensor>(4);

  const auto& emb_dims = emb->Shape().GetDims();
  const auto& q_dims = q->Shape().GetDims();

  // Validate ranks BEFORE indexing into the dims. The previous code read
  // q_dims[0..3] first, which is out-of-bounds for lower-rank inputs; its
  // rank-4 message also wrongly referred to emb_dims.
  ORT_ENFORCE(emb_dims.size() == 3, "emb_dims should be 3D");
  ORT_ENFORCE(q_dims.size() == 4, "q_dims should be 4D");

  const int batch_size = static_cast<int>(q_dims[0]);
  const int num_heads = static_cast<int>(q_dims[1]);
  const int seq_len = static_cast<int>(q_dims[2]);
  const int dim = static_cast<int>(q_dims[3]);

  // emb must line up with q on every axis except num_heads, over which it is
  // broadcast inside the kernel.
  ORT_ENFORCE(emb_dims[0] == batch_size, "emb_dims[0] should match q_dims[0]");
  ORT_ENFORCE(emb_dims[1] == seq_len, "emb_dims[1] should match q_dims[2]");
  ORT_ENFORCE(emb_dims[2] == dim, "emb_dims[2] should match q_dims[3]");

  Tensor* output1 = context->Output(0, q_dims);
  Tensor* output2 = context->Output(1, q_dims);

  typedef typename ToCudaType<T>::MappedType CudaT;
  typedef typename ToCudaType<U>::MappedType CudaU;
  return LaunchGemmaRotaryEmbeddingKernel<CudaT>(
      Stream(context),
      reinterpret_cast<CudaT*>(output1->template MutableData<T>()),
      reinterpret_cast<CudaT*>(output2->template MutableData<T>()),
      reinterpret_cast<const CudaU*>(emb->template Data<U>()),
      reinterpret_cast<const CudaT*>(q->template Data<T>()),
      reinterpret_cast<const CudaT*>(q_rot->template Data<T>()),
      reinterpret_cast<const CudaT*>(k->template Data<T>()),
      reinterpret_cast<const CudaT*>(k_rot->template Data<T>()),
      batch_size,
      num_heads,
      seq_len,
      dim);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using onnxruntime::cuda::CudaKernel;
using onnxruntime::cuda::ToCudaType;
// CUDA kernel computing fused Gemma rotary positional embeddings:
//   output1 = q * cos(emb) + q_rot * sin(emb)
//   output2 = k * cos(emb) + k_rot * sin(emb)
// T is the q/k/output element type (float16), U the embedding element type
// (float). See gemma_rotary_emb_impl.cu for the subgraph this fuses.
template <typename T, typename U>
class GemmaRotaryEmbedding final : public CudaKernel {
 public:
  // explicit: a single-argument constructor should not act as an implicit
  // conversion from OpKernelInfo.
  explicit GemmaRotaryEmbedding(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,104 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
Kernel implementation for Gemma rotary embeddings.
This implements the subgraph below:
(emb)
/ \
/ \
Sin Cos
| |
Cast Cast
| |
Unsqueeze Unsqueeze
\/ \/ \/ \/
Mul Mul Mul Mul
\ / \ /
Add Add
| |
(output1) (output2)
*/
#include <cuda_fp16.h>
#include <cmath>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h"
using namespace onnxruntime::cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock;
// One thread per element of the flattened (batch, num_heads, seq_len, dim)
// q/k tensors. emb has shape (batch, seq_len, dim) and is broadcast over the
// num_heads axis.
template <typename T, typename U>
__global__ void GemmaRotaryEmb(
    T* output1,
    T* output2,
    const U* emb,
    const T* q,
    const T* q_rot,
    const T* k,
    const T* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim) {
  const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (qk_idx < batch_size * num_heads * seq_len * dim) {
    // Drop the num_heads axis: index [i, j, k, l] -> [i, k, l].
    const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim);
    // Load emb once instead of twice (one global read feeds both sin and cos).
    const U emb_val = emb[emb_idx];
    const T sin_val = static_cast<T>(sin(emb_val));
    const T cos_val = static_cast<T>(cos(emb_val));
    output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val;
    output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val;
  }
}
// Launches GemmaRotaryEmb over batch_size * num_heads * seq_len * dim
// elements, one thread per element. Returns the CUDA launch status.
template <typename T, typename U>
Status LaunchGemmaRotaryEmbeddingKernel(
    cudaStream_t stream,
    T* output1,
    T* output2,
    const U* emb,
    const T* q,
    const T* q_rot,
    const T* k,
    const T* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim) {
  const int num_elems = batch_size * num_heads * seq_len * dim;
  // Integer ceiling division. The previous float-based ceil() loses precision
  // once the element count exceeds 2^24, which could under-provision the grid.
  const int blocksPerGrid = (num_elems + kThreadsPerBlock - 1) / kThreadsPerBlock;
  GemmaRotaryEmb<<<blocksPerGrid, kThreadsPerBlock, 0, stream>>>(
      output1, output2,
      emb, q, q_rot, k, k_rot,
      batch_size, num_heads, seq_len, dim);
  return CUDA_CALL(cudaGetLastError());
}
// Explicit instantiation for the only registered type pair: T = half (maps to
// MLFloat16 on the host side), U = float.
template Status LaunchGemmaRotaryEmbeddingKernel<half, float>(
    cudaStream_t stream,
    half* output1,
    half* output2,
    const float* emb,
    const half* q,
    const half* q_rot,
    const half* k,
    const half* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Launches the GemmaRotaryEmbedding CUDA kernel computing
//   output1 = q * cos(emb) + q_rot * sin(emb)
//   output2 = k * cos(emb) + k_rot * sin(emb)
// T: q/k/output device element type (e.g. half); U: embedding element type
// (e.g. float). emb is (batch_size, seq_len, dim) and is broadcast over
// num_heads; the remaining tensors are (batch_size, num_heads, seq_len, dim).
// All pointers are contiguous device buffers. Returns the launch status.
template <typename T, typename U>
Status LaunchGemmaRotaryEmbeddingKernel(
    cudaStream_t stream,
    T* output1,
    T* output2,
    const U* emb,
    const T* q,
    const T* q_rot,
    const T* k,
    const T* k_rot,
    const int batch_size,
    const int num_heads,
    const int seq_len,
    const int dim);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@ -302,6 +303,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,

View file

@ -1215,6 +1215,71 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
constexpr const char* GemmaRotaryEmbedding_ver1_doc = R"DOC(
GemmaRotaryEmbedding implements the part of rotary positional embeddings (RoPE) shown below, from modeling_gemma.py.
Here's onnxscript that was tested

from onnxscript import FLOAT, FLOAT16, script
from onnxscript import opset18 as op

@script()
def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]):
    sin_val = op.Sin(emb)
    casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported.
    cos_val = op.Cos(emb)
    casted_cos = op.Cast(cos_val, to=10)
    unsqueezed_sin = op.Unsqueeze(casted_sin, [1])
    unsqueezed_cos = op.Unsqueeze(casted_cos, [1])
    q_embed = (q * casted_cos) + (q_rot * casted_sin)
    k_embed = (k * casted_cos) + (k_rot * casted_sin)
    return q_embed, k_embed

onnx_model = gemma_rotary_embedding.to_model_proto()
)DOC";

// Contrib op schema: 5 inputs (emb + q/q_rot/k/k_rot), 2 outputs, with emb in
// fp32 (U) and everything else in fp16 (T). Outputs take their type and shape
// from input 1 (q).
ONNX_MS_OPERATOR_SET_SCHEMA(
    GemmaRotaryEmbedding, 1,
    OpSchema()
        .SetDoc(GemmaRotaryEmbedding_ver1_doc)
        .Input(0,
               "emb",
               // fixed "embeddding" typo
               "embedding - 3D tensor with shape (batch_size, seq_len, dim)",
               "U")
        .Input(1,
               "q",
               "q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
               "T")
        .Input(2,
               "q_rot",
               "half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
               "T")
        .Input(3,
               "k",
               "k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
               "T")
        .Input(4,
               "k_rot",
               // was "k state", which described input 3; this is the rotated half
               "half rotated k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
               "T")
        .Output(0,
                "output1",
                "4D tensor with shape (batch_size, num_heads, seq_len, dim)",
                "T")
        .Output(1,
                "output2",
                "4D tensor with shape (batch_size, num_heads, seq_len, dim)",
                "T")
        .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float16 tensors.")
        .TypeConstraint("U", {"tensor(float)"}, "Constrain input 0 type to float tensors")
        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
          // Both outputs mirror q (input 1) in element type and shape.
          propagateElemTypeFromInputToOutput(ctx, 1, 0);
          propagateElemTypeFromInputToOutput(ctx, 1, 1);
          propagateShapeFromInputToOutput(ctx, 1, 0);
          propagateShapeFromInputToOutput(ctx, 1, 1);
        }));
constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC(
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,

View file

@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
@ -208,6 +209,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm)>());

View file

@ -0,0 +1,104 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cassert>
#include "gtest/gtest.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include <cstdlib> // For rand() and srand()
namespace onnxruntime {
namespace test {
constexpr auto k_random_data_min = -1.0f;
constexpr auto k_random_data_max = 1.0f;
namespace {

// NOTE(review): TensorType is not referenced anywhere in this test file —
// presumably copied from a sibling test; confirm before removing.
enum class TensorType {
  kFloat,
  kFloat16,
  kBFloat16
};

}  // anonymous namespace
// Computes the CPU reference for GemmaRotaryEmbedding:
//   output1 = q * cos(emb) + q_rot * sin(emb)
//   output2 = k * cos(emb) + k_rot * sin(emb)
// emb_data has shape (batch, seq_len, dim) and is broadcast over the
// num_heads axis; mul_dim is {batch_size, num_heads, seq_len, dim}.
// Results are appended to output1/output2 in row-major (i, j, k, l) order.
static void calculateExpectedOutput(const std::vector<float>& emb_data,
                                    const std::vector<MLFloat16>& q_data,
                                    const std::vector<MLFloat16>& q_rot_data,
                                    const std::vector<MLFloat16>& k_data,
                                    const std::vector<MLFloat16>& k_rot_data,
                                    const std::vector<int64_t>& mul_dim,
                                    std::vector<MLFloat16>& output1,
                                    std::vector<MLFloat16>& output2) {
  for (long int i = 0; i < mul_dim[0]; ++i) {
    for (long int j = 0; j < mul_dim[1]; ++j) {
      for (long int k = 0; k < mul_dim[2]; ++k) {
        for (long int l = 0; l < mul_dim[3]; ++l) {
          // emb's batch stride is seq_len * dim (mul_dim[2] * mul_dim[3]),
          // matching the CUDA kernel's broadcast indexing. The previous code
          // used num_heads * dim (mul_dim[1] * mul_dim[3]), a latent bug that
          // was masked because the test uses batch_size == 1.
          long int embIdx = i * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l;
          long int mulIdx = i * mul_dim[1] * mul_dim[2] * mul_dim[3] + j * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l;
          MLFloat16 sin_val = static_cast<MLFloat16>(sin(emb_data[embIdx]));
          MLFloat16 cos_val = static_cast<MLFloat16>(cos(emb_data[embIdx]));
          MLFloat16 q_val = static_cast<MLFloat16>(q_data[mulIdx]);
          MLFloat16 q_rot_val = static_cast<MLFloat16>(q_rot_data[mulIdx]);
          MLFloat16 k_val = static_cast<MLFloat16>(k_data[mulIdx]);
          MLFloat16 k_rot_val = static_cast<MLFloat16>(k_rot_data[mulIdx]);
          output1.push_back(static_cast<MLFloat16>(q_val * cos_val + q_rot_val * sin_val));
          output2.push_back(static_cast<MLFloat16>(k_val * cos_val + k_rot_val * sin_val));
        }
      }
    }
  }
}
static void RunTest() {
std::string op_type = "GemmaRotaryEmbedding";
std::vector<int64_t> emb_dim = {1, 2, 2};
std::vector<int64_t> mul_dim = {1, 3, 2, 2};
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
int min_cuda_architecture = 530;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
if (enable_cuda) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (execution_providers.size() == 0) {
// Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline)
return;
}
OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
// create rand inputs
RandomValueGenerator random{};
const std::vector<float> emb_data = random.Uniform<float>(emb_dim, k_random_data_min, k_random_data_max);
const std::vector<MLFloat16> q = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
const std::vector<MLFloat16> q_rot = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
const std::vector<MLFloat16> k = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
const std::vector<MLFloat16> k_rot = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
std::vector<MLFloat16> output1;
std::vector<MLFloat16> output2;
calculateExpectedOutput(emb_data, q, q_rot, k, k_rot, mul_dim, output1, output2);
test.AddInput<float>("emb", emb_dim, emb_data);
test.AddInput<MLFloat16>("q_data", mul_dim, q);
test.AddInput<MLFloat16>("q_rot_data", mul_dim, q_rot);
test.AddInput<MLFloat16>("k_data", mul_dim, k);
test.AddInput<MLFloat16>("k_rot_data", mul_dim, k_rot);
test.AddOutput<MLFloat16>("output1", mul_dim, output1);
test.AddOutput<MLFloat16>("output2", mul_dim, output2);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
// Small end-to-end check of the GemmaRotaryEmbedding CUDA kernel; RunTest
// skips itself when no fp16-capable CUDA device is available.
TEST(GemmaRotaryEmbeddingTest, GemmaRotaryEmbedding_Small) {
  RunTest();
}
} // namespace test
} // namespace onnxruntime