From c11941289b2fa516073016ec5474dc426e4c916d Mon Sep 17 00:00:00 2001
From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com>
Date: Tue, 16 Apr 2024 15:31:56 -0700
Subject: [PATCH] Add Gemma Rotary Embedding (#20267)
### Description
Add GemmaRotaryEmbedding kernel which includes sin and cos in
GemmaRotaryEmbedding forward and apply_rotary_pos_emb. See
gemma_rotary_emb_impl.cu for subgraph details
### Motivation and Context
---
docs/ContribOperators.md | 64 +++++++++++
docs/OperatorKernels.md | 1 +
.../contrib_ops/cuda/bert/gemma_rotary_emb.cc | 75 +++++++++++++
.../contrib_ops/cuda/bert/gemma_rotary_emb.h | 24 ++++
.../cuda/bert/gemma_rotary_emb_impl.cu | 104 ++++++++++++++++++
.../cuda/bert/gemma_rotary_emb_impl.h | 29 +++++
.../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 +
.../core/graph/contrib_ops/bert_defs.cc | 65 +++++++++++
onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 +
.../test/contrib_ops/gemma_rotary_emb_test.cc | 104 ++++++++++++++++++
10 files changed, 470 insertions(+)
create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc
create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h
create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu
create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h
create mode 100644 onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9b45cc0270..3d984a54c0 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -41,6 +41,7 @@ Do not modify directly.*
* com.microsoft.Gelu
* com.microsoft.GemmFastGelu
* com.microsoft.GemmFloat8
+ * com.microsoft.GemmaRotaryEmbedding
* com.microsoft.GreedySearch
* com.microsoft.GridSample
* com.microsoft.GroupNorm
@@ -2210,6 +2211,69 @@ This version of the operator has been available since version 1 of the 'com.micr
+### <a name="com.microsoft.GemmaRotaryEmbedding"></a><a name="com.microsoft.gemmarotaryembedding">**com.microsoft.GemmaRotaryEmbedding**</a>
+
+ GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py.
+
+ Here's onnxscript that was tested
+
+ from onnxscript import FLOAT, FLOAT16, script
+ from onnxscript import opset18 as op
+
+ @script()
+ def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]):
+ sin_val = op.Sin(emb)
+ casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported.
+ cos_val = op.Cos(emb)
+ casted_cos = op.Cast(cos_val, to=10)
+ unsqueezed_sin = op.Unsqueeze(casted_sin, [1])
+ unsqueezed_cos = op.Unsqueeze(casted_cos, [1])
+ q_embed = (q * casted_cos) + (q_rot * casted_sin)
+ k_embed = (k * casted_cos) + (k_rot * casted_sin)
+ return q_embed, k_embed
+
+ onnx_model = gemma_rotary_embedding.to_model_proto()
+
+
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Inputs
+
+
+<dl>
+<dt><tt>emb</tt> : U</dt>
+<dd>embedding - 3D tensor with shape (batch_size, seq_len, dim)</dd>
+<dt><tt>q</tt> : T</dt>
+<dd>q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
+<dt><tt>q_rot</tt> : T</dt>
+<dd>half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
+<dt><tt>k</tt> : T</dt>
+<dd>k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
+<dt><tt>k_rot</tt> : T</dt>
+<dd>half rotated k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
+</dl>
+
+
+#### Outputs
+
+
+<dl>
+<dt><tt>output1</tt> : T</dt>
+<dd>4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
+<dt><tt>output2</tt> : T</dt>
+<dd>4D tensor with shape (batch_size, num_heads, seq_len, dim)</dd>
+</dl>
+
+
+#### Type Constraints
+
+
+<dl>
+<dt><tt>T</tt> : tensor(float16)</dt>
+<dd>Constrain input and output types to float16 tensors.</dd>
+<dt><tt>U</tt> : tensor(float)</dt>
+<dd>Constrain input 0 type to float tensors</dd>
+</dl>
+
+
+
### **com.microsoft.GreedySearch**
Greedy Search for text generation.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d184485cb5..5bae5ea626 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -868,6 +868,7 @@ Do not modify directly.*
 |GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br> **TS** = tensor(float)|
+|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br> **U** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc
new file mode 100644
index 0000000000..49bf79188e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/bert/gemma_rotary_emb.h"
+#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T, U)                                 \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                    \
+      GemmaRotaryEmbedding,                                         \
+      kMSDomain,                                                    \
+      1,                                                            \
+      T,                                                            \
+      kCudaExecutionProvider,                                       \
+      (*KernelDefBuilder::Create())                                 \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())    \
+          .TypeConstraint("U", DataTypeImpl::GetTensorType<U>()),   \
+      GemmaRotaryEmbedding<T, U>);
+
+REGISTER_KERNEL_TYPED(MLFloat16, float)
+
+template <typename T, typename U>
+GemmaRotaryEmbedding<T, U>::GemmaRotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
+}
+
+template <typename T, typename U>
+Status GemmaRotaryEmbedding<T, U>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* emb = context->Input<Tensor>(0);
+  const Tensor* q = context->Input<Tensor>(1);
+  const Tensor* q_rot = context->Input<Tensor>(2);
+  const Tensor* k = context->Input<Tensor>(3);
+  const Tensor* k_rot = context->Input<Tensor>(4);
+
+  const auto& emb_dims = emb->Shape().GetDims();
+  const auto& q_dims = q->Shape().GetDims();
+  // q_dims should be [batch_size, num_heads, seq_len, dim]
+  // emb_dims should be [batch_size, seq_len, dim]; validate ranks before indexing.
+  ORT_ENFORCE(emb_dims.size() == 3, "emb_dims should be 3D");
+  ORT_ENFORCE(q_dims.size() == 4, "q_dims should be 4D");
+
+  int batch_size = static_cast<int>(q_dims[0]);
+  int num_heads = static_cast<int>(q_dims[1]);
+  int seq_len = static_cast<int>(q_dims[2]);
+  int dim = static_cast<int>(q_dims[3]);
+  ORT_ENFORCE(emb_dims[0] == batch_size, "emb_dims[0] should match q_dims[0]");
+  ORT_ENFORCE(emb_dims[1] == seq_len, "emb_dims[1] should match q_dims[2]");
+  ORT_ENFORCE(emb_dims[2] == dim, "emb_dims[2] should match q_dims[3]");
+
+  Tensor* output1 = context->Output(0, q_dims);
+  Tensor* output2 = context->Output(1, q_dims);
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  typedef typename ToCudaType<U>::MappedType CudaU;
+  return LaunchGemmaRotaryEmbeddingKernel<CudaT, CudaU>(
+      Stream(context),
+      reinterpret_cast<CudaT*>(output1->template MutableData<T>()),
+      reinterpret_cast<CudaT*>(output2->template MutableData<T>()),
+      reinterpret_cast<const CudaU*>(emb->template Data<U>()),
+      reinterpret_cast<const CudaT*>(q->template Data<T>()),
+      reinterpret_cast<const CudaT*>(q_rot->template Data<T>()),
+      reinterpret_cast<const CudaT*>(k->template Data<T>()),
+      reinterpret_cast<const CudaT*>(k_rot->template Data<T>()),
+      batch_size,
+      num_heads,
+      seq_len,
+      dim);
+}
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h
new file mode 100644
index 0000000000..e63236d2ab
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using onnxruntime::cuda::CudaKernel;
+using onnxruntime::cuda::ToCudaType;
+
+template <typename T, typename U>
+class GemmaRotaryEmbedding final : public CudaKernel {
+ public:
+  GemmaRotaryEmbedding(const OpKernelInfo& info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu
new file mode 100644
index 0000000000..9e00ca713a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu
@@ -0,0 +1,104 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+/*
+Kernel implementation for Gemma rotary embeddings.
+This implements the below subgraph:
+            (emb)
+            /   \
+           /     \
+         Sin     Cos
+          |       |
+        Cast     Cast
+          |       |
+     Unsqueeze  Unsqueeze
+     \/      \/  \/     \/
+     Mul     Mul Mul    Mul
+       \     /     \    /
+         Add        Add
+          |          |
+      (output1)  (output2)
+*/
+
+#include <cuda_fp16.h>
+#include <cmath>
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock;
+
+template <typename T, typename U>
+__global__ void GemmaRotaryEmb(
+    T* output1,
+    T* output2,
+    const U* emb,
+    const T* q,
+    const T* q_rot,
+    const T* k,
+    const T* k_rot,
+    const int batch_size,
+    const int num_heads,
+    const int seq_len,
+    const int dim) {
+
+  const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // emb is [batch, seq, dim]: map flat q/k index [i, j, k, l] -> [i, k, l]
+  const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim);
+  if (qk_idx < batch_size * num_heads * seq_len * dim) {
+    T sin_val = static_cast<T>(sin(emb[emb_idx]));
+    T cos_val = static_cast<T>(cos(emb[emb_idx]));
+    output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val;
+    output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val;
+  }
+}
+
+template <typename T, typename U>
+Status LaunchGemmaRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output1,
+    T* output2,
+    const U* emb,
+    const T* q,
+    const T* q_rot,
+    const T* k,
+    const T* k_rot,
+    const int batch_size,
+    const int num_heads,
+    const int seq_len,
+    const int dim
+    ) {
+  int blocksPerGrid = static_cast<int>(ceil(float(batch_size * num_heads * seq_len * dim) / kThreadsPerBlock));
+
+  GemmaRotaryEmb<T, U><<<blocksPerGrid, kThreadsPerBlock, 0, stream>>>(
+      output1, output2,
+      emb, q, q_rot, k, k_rot,
+      batch_size, num_heads, seq_len, dim
+  );
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchGemmaRotaryEmbeddingKernel<half, float>(
+    cudaStream_t stream,
+    half* output1,
+    half* output2,
+    const float* emb,
+    const half* q,
+    const half* q_rot,
+    const half* k,
+    const half* k_rot,
+    const int batch_size,
+    const int num_heads,
+    const int seq_len,
+    const int dim);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h
new file mode 100644
index 0000000000..c57fbe0d7e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T, typename U>
+Status LaunchGemmaRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output1,
+    T* output2,
+    const U* emb,
+    const T* q,
+    const T* q_rot,
+    const T* k,
+    const T* k_rot,
+    const int batch_size,
+    const int num_heads,
+    const int seq_len,
+    const int dim);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 3621ffc5c6..583e67b2e6 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@@ -302,6 +303,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index cc10e73be3..adfa1b61e1 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1215,6 +1215,71 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
           propagateShapeFromInputToOutput(ctx, 0, 0);
         }));
 
+constexpr const char* GemmaRotaryEmbedding_ver1_doc = R"DOC(
+GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py.
+
+Here's onnxscript that was tested
+
+from onnxscript import FLOAT, FLOAT16, script
+from onnxscript import opset18 as op
+
+@script()
+def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]):
+    sin_val = op.Sin(emb)
+    casted_sin = op.Cast(sin_val, to=10)  # for fp16 mix-precision training. Other types are not supported.
+    cos_val = op.Cos(emb)
+    casted_cos = op.Cast(cos_val, to=10)
+    unsqueezed_sin = op.Unsqueeze(casted_sin, [1])
+    unsqueezed_cos = op.Unsqueeze(casted_cos, [1])
+    q_embed = (q * casted_cos) + (q_rot * casted_sin)
+    k_embed = (k * casted_cos) + (k_rot * casted_sin)
+    return q_embed, k_embed
+
+onnx_model = gemma_rotary_embedding.to_model_proto()
+
+
+)DOC";
+ONNX_MS_OPERATOR_SET_SCHEMA(
+    GemmaRotaryEmbedding, 1,
+    OpSchema()
+        .SetDoc(GemmaRotaryEmbedding_ver1_doc)
+        .Input(0,
+               "emb",
+               "embedding - 3D tensor with shape (batch_size, seq_len, dim)",
+               "U")
+        .Input(1,
+               "q",
+               "q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
+               "T")
+        .Input(2,
+               "q_rot",
+               "half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
+               "T")
+        .Input(3,
+               "k",
+               "k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
+               "T")
+        .Input(4,
+               "k_rot",
+               "half rotated k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)",
+               "T")
+        .Output(0,
+                "output1",
+                "4D tensor with shape (batch_size, num_heads, seq_len, dim)",
+                "T")
+        .Output(1,
+                "output2",
+                "4D tensor with shape (batch_size, num_heads, seq_len, dim)",
+                "T")
+        .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float16 tensors.")
+        .TypeConstraint("U", {"tensor(float)"}, "Constrain input 0 type to float tensors")
+        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+          propagateElemTypeFromInputToOutput(ctx, 1, 0);
+          propagateElemTypeFromInputToOutput(ctx, 1, 1);
+          propagateShapeFromInputToOutput(ctx, 1, 0);
+          propagateShapeFromInputToOutput(ctx, 1, 1);
+        }));
+
 constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC(
 EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
 The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index ef86352080..a23ad4678b 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
@@ -208,6 +209,7 @@ class OpSet_Microsoft_ver1 {
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding)>());
+      fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm)>());
diff --git a/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc b/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc
new file mode 100644
index 0000000000..80adf04f40
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc
@@ -0,0 +1,104 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include "gtest/gtest.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+#include <cstdlib>  // For rand() and srand()
+
+namespace onnxruntime {
+namespace test {
+
+constexpr auto k_random_data_min = -1.0f;
+constexpr auto k_random_data_max = 1.0f;
+
+namespace {
+enum class TensorType {
+  kFloat,
+  kFloat16,
+  kBFloat16
+};
+}  // anonymous namespace
+
+static void calculateExpectedOutput(const std::vector<float>& emb_data,
+                                    const std::vector<MLFloat16>& q_data,
+                                    const std::vector<MLFloat16>& q_rot_data,
+                                    const std::vector<MLFloat16>& k_data,
+                                    const std::vector<MLFloat16>& k_rot_data,
+                                    const std::vector<int64_t>& mul_dim,
+                                    std::vector<MLFloat16>& output1,
+                                    std::vector<MLFloat16>& output2) {
+  for (long int i = 0; i < mul_dim[0]; ++i) {
+    for (long int j = 0; j < mul_dim[1]; ++j) {
+      for (long int k = 0; k < mul_dim[2]; ++k) {
+        for (long int l = 0; l < mul_dim[3]; ++l) {
+          long int embIdx = i * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l;  // emb is (batch, seq_len, dim): batch stride is seq_len*dim
+          long int mulIdx = i * mul_dim[1] * mul_dim[2] * mul_dim[3] + j * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l;
+
+          MLFloat16 sin_val = static_cast<MLFloat16>(sin(emb_data[embIdx]));
+          MLFloat16 cos_val = static_cast<MLFloat16>(cos(emb_data[embIdx]));
+          MLFloat16 q_val = static_cast<MLFloat16>(q_data[mulIdx]);
+          MLFloat16 q_rot_val = static_cast<MLFloat16>(q_rot_data[mulIdx]);
+          MLFloat16 k_val = static_cast<MLFloat16>(k_data[mulIdx]);
+          MLFloat16 k_rot_val = static_cast<MLFloat16>(k_rot_data[mulIdx]);
+          output1.push_back(static_cast<MLFloat16>(q_val * cos_val + q_rot_val * sin_val));
+          output2.push_back(static_cast<MLFloat16>(k_val * cos_val + k_rot_val * sin_val));
+        }
+      }
+    }
+  }
+}
+
+static void RunTest() {
+  std::string op_type = "GemmaRotaryEmbedding";
+  std::vector<int64_t> emb_dim = {1, 2, 2};
+  std::vector<int64_t> mul_dim = {1, 3, 2, 2};
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+
+  int min_cuda_architecture = 530;
+  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+
+  if (enable_cuda) {
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+  }
+
+  if (execution_providers.size() == 0) {
+    // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline)
+    return;
+  }
+
+  OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
+
+  // create rand inputs
+  RandomValueGenerator random{};
+  const std::vector<float> emb_data = random.Uniform<float>(emb_dim, k_random_data_min, k_random_data_max);
+  const std::vector<MLFloat16> q = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
+  const std::vector<MLFloat16> q_rot = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
+  const std::vector<MLFloat16> k = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
+  const std::vector<MLFloat16> k_rot = random.Uniform<MLFloat16>(mul_dim, k_random_data_min, k_random_data_max);
+
+  std::vector<MLFloat16> output1;
+  std::vector<MLFloat16> output2;
+
+  calculateExpectedOutput(emb_data, q, q_rot, k, k_rot, mul_dim, output1, output2);
+
+  test.AddInput<float>("emb", emb_dim, emb_data);
+  test.AddInput<MLFloat16>("q_data", mul_dim, q);
+  test.AddInput<MLFloat16>("q_rot_data", mul_dim, q_rot);
+  test.AddInput<MLFloat16>("k_data", mul_dim, k);
+  test.AddInput<MLFloat16>("k_rot_data", mul_dim, k_rot);
+  test.AddOutput<MLFloat16>("output1", mul_dim, output1);
+  test.AddOutput<MLFloat16>("output2", mul_dim, output2);
+
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(GemmaRotaryEmbeddingTest, GemmaRotaryEmbedding_Small) {
+  RunTest();
+}
+
+}  // namespace test
+}  // namespace onnxruntime
\ No newline at end of file