From c11941289b2fa516073016ec5474dc426e4c916d Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:31:56 -0700 Subject: [PATCH] Add Gemma Rotary Embedding (#20267) ### Description Add GemmaRotaryEmbedding kernel which includes sin and cos in GemmaRotaryEmbedding forward and apply_rotary_pos_emb. See gemma_rotary_emb_impl.cu for subgraph details ### Motivation and Context --- docs/ContribOperators.md | 64 +++++++++++ docs/OperatorKernels.md | 1 + .../contrib_ops/cuda/bert/gemma_rotary_emb.cc | 75 +++++++++++++ .../contrib_ops/cuda/bert/gemma_rotary_emb.h | 24 ++++ .../cuda/bert/gemma_rotary_emb_impl.cu | 104 ++++++++++++++++++ .../cuda/bert/gemma_rotary_emb_impl.h | 29 +++++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/bert_defs.cc | 65 +++++++++++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../test/contrib_ops/gemma_rotary_emb_test.cc | 104 ++++++++++++++++++ 10 files changed, 470 insertions(+) create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h create mode 100644 onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9b45cc0270..3d984a54c0 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -41,6 +41,7 @@ Do not modify directly.* * com.microsoft.Gelu * com.microsoft.GemmFastGelu * com.microsoft.GemmFloat8 + * com.microsoft.GemmaRotaryEmbedding * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm @@ -2210,6 +2211,69 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmaRotaryEmbedding** + + GemmaRotaryEmbedding is the implementation of below 
part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py. + + Here's onnxscript that was tested + + from onnxscript import FLOAT, FLOAT16, script + from onnxscript import opset18 as op + + @script() + def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]): + sin_val = op.Sin(emb) + casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported. + cos_val = op.Cos(emb) + casted_cos = op.Cast(cos_val, to=10) + unsqueezed_sin = op.Unsqueeze(casted_sin, [1]) + unsqueezed_cos = op.Unsqueeze(casted_cos, [1]) + q_embed = (q * casted_cos) + (q_rot * casted_sin) + k_embed = (k * casted_cos) + (k_rot * casted_sin) + return q_embed, k_embed + + onnx_model = gemma_rotary_embedding.to_model_proto() + + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
emb : U
+
embedding - 3D tensor with shape (batch_size, seq_len, dim)
+
q : T
+
q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
q_rot : T
+
half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
k : T
+
k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
k_rot : T
+
half rotated k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
+ +#### Outputs + +
+
output1 : T
+
4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
output2 : T
+
4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output types to float16 tensors.
+
U : tensor(float)
+
Constrain input 0 type to float tensors
+
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d184485cb5..5bae5ea626 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,6 +868,7 @@ Do not modify directly.* |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| +|GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc new file mode 100644 index 0000000000..49bf79188e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/bert/gemma_rotary_emb.h" +#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T, U) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GemmaRotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ + GemmaRotaryEmbedding); + +REGISTER_KERNEL_TYPED(MLFloat16, float) + +template +GemmaRotaryEmbedding::GemmaRotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { +} + +template +Status GemmaRotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* emb = context->Input(0); + const Tensor* q = context->Input(1); + const Tensor* q_rot = context->Input(2); + const Tensor* k = context->Input(3); + const Tensor* k_rot = context->Input(4); + + const auto& emb_dims = emb->Shape().GetDims(); + const auto& q_dims = q->Shape().GetDims(); + int batch_size = static_cast(q_dims[0]); + int num_heads = static_cast(q_dims[1]); + int seq_len = static_cast(q_dims[2]); + int dim = static_cast(q_dims[3]); + + // q_dims should be [batch_size, num_heads, seq_len, dim] + // emb_dims should be [batch_size, seq, dim] + ORT_ENFORCE(emb_dims.size() == 3, "emb_dims should be 3D"); + ORT_ENFORCE(q_dims.size() == 4, "emb_dims should be 4D"); + ORT_ENFORCE(emb_dims[0] == batch_size, "emb_dims[0] should match q_dims[0]"); + 
ORT_ENFORCE(emb_dims[1] == seq_len, "emb_dims[1] should match q_dims[2]"); + ORT_ENFORCE(emb_dims[2] == dim, "emb_dims[2] should match q_dims[3]"); + + Tensor* output1 = context->Output(0, q_dims); + Tensor* output2 = context->Output(1, q_dims); + + typedef typename ToCudaType::MappedType CudaT; + typedef typename ToCudaType::MappedType CudaU; + return LaunchGemmaRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output1->template MutableData()), + reinterpret_cast(output2->template MutableData()), + reinterpret_cast(emb->template Data()), + reinterpret_cast(q->template Data()), + reinterpret_cast(q_rot->template Data()), + reinterpret_cast(k->template Data()), + reinterpret_cast(k_rot->template Data()), + batch_size, + num_heads, + seq_len, + dim); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h new file mode 100644 index 0000000000..e63236d2ab --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::cuda::CudaKernel; +using onnxruntime::cuda::ToCudaType; + +template +class GemmaRotaryEmbedding final : public CudaKernel { + public: + GemmaRotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu new file mode 100644 index 0000000000..9e00ca713a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu @@ -0,0 +1,104 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ +/* +Kernel implementation for Gamma rotary embeddings. +This implementation below subgraph + (emb) + / \ + / \ + Sin Cos + | | + Cast Cast + | | + Unsqueeze Unsqueeze + \/ \/ \/ \/ + Mul Mul Mul Mul + \ / \ / + Add Add + | | + (output1) (output2) +*/ + +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock; + +template +__global__ void GemmaRotaryEmb( + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim) { + + const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x; + // index [i, j, k, l] -> [i, k, l] + const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim); + if (qk_idx < batch_size * num_heads * seq_len * dim) { + T sin_val = static_cast(sin(emb[emb_idx])); + T cos_val = 
static_cast(cos(emb[emb_idx])); + output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val; + output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val; + } +} + +template +Status LaunchGemmaRotaryEmbeddingKernel( + cudaStream_t stream, + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim + ) { + int blocksPerGrid = static_cast(ceil(float(batch_size * num_heads * seq_len * dim) / kThreadsPerBlock)); + + GemmaRotaryEmb<<>>( + output1, output2, + emb, q, q_rot, k, k_rot, + batch_size, num_heads, seq_len, dim + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchGemmaRotaryEmbeddingKernel( + cudaStream_t stream, + half* output1, + half* output2, + const float* emb, + const half* q, + const half* q_rot, + const half* k, + const half* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h new file mode 100644 index 0000000000..c57fbe0d7e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchGemmaRotaryEmbeddingKernel( + cudaStream_t stream, + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 3621ffc5c6..583e67b2e6 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -302,6 +303,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc 
index cc10e73be3..adfa1b61e1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1215,6 +1215,71 @@ ONNX_MS_OPERATOR_SET_SCHEMA( propagateShapeFromInputToOutput(ctx, 0, 0); })); +constexpr const char* GemmaRotaryEmbedding_ver1_doc = R"DOC( +GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py. + +Here's onnxscript that was tested + +from onnxscript import FLOAT, FLOAT16, script +from onnxscript import opset18 as op + +@script() +def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]): + sin_val = op.Sin(emb) + casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported. + cos_val = op.Cos(emb) + casted_cos = op.Cast(cos_val, to=10) + unsqueezed_sin = op.Unsqueeze(casted_sin, [1]) + unsqueezed_cos = op.Unsqueeze(casted_cos, [1]) + q_embed = (q * casted_cos) + (q_rot * casted_sin) + k_embed = (k * casted_cos) + (k_rot * casted_sin) + return q_embed, k_embed + +onnx_model = gemma_rotary_embedding.to_model_proto() + + +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + GemmaRotaryEmbedding, 1, + OpSchema() + .SetDoc(GemmaRotaryEmbedding_ver1_doc) + .Input(0, + "emb", + "embeddding - 3D tensor with shape (batch_size, seq_len, dim)", + "U") + .Input(1, + "q", + "q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Input(2, + "q_rot", + "half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Input(3, + "k", + "k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Input(4, + "k_rot", + "k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Output(0, + "output1", + 
"4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Output(1, + "output2", + "4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float16 tensors.") + .TypeConstraint("U", {"tensor(float)"}, "Constrain input 0 type to float tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + propagateElemTypeFromInputToOutput(ctx, 1, 1); + propagateShapeFromInputToOutput(ctx, 1, 0); + propagateShapeFromInputToOutput(ctx, 1, 1); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index ef86352080..a23ad4678b 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm); @@ -208,6 +209,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git 
a/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc b/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc new file mode 100644 index 0000000000..80adf04f40 --- /dev/null +++ b/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include // For rand() and srand() + +namespace onnxruntime { +namespace test { + +constexpr auto k_random_data_min = -1.0f; +constexpr auto k_random_data_max = 1.0f; + +namespace { +enum class TensorType { + kFloat, + kFloat16, + kBFloat16 +}; +} // anonymous namespace + +static void calculateExpectedOutput(const std::vector& emb_data, + const std::vector& q_data, + const std::vector& q_rot_data, + const std::vector& k_data, + const std::vector& k_rot_data, + const std::vector& mul_dim, + std::vector& output1, + std::vector& output2) { + for (long int i = 0; i < mul_dim[0]; ++i) { + for (long int j = 0; j < mul_dim[1]; ++j) { + for (long int k = 0; k < mul_dim[2]; ++k) { + for (long int l = 0; l < mul_dim[3]; ++l) { + long int embIdx = i * mul_dim[1] * mul_dim[3] + k * mul_dim[3] + l; + long int mulIdx = i * mul_dim[1] * mul_dim[2] * mul_dim[3] + j * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l; + + MLFloat16 sin_val = static_cast(sin(emb_data[embIdx])); + MLFloat16 cos_val = static_cast(cos(emb_data[embIdx])); + MLFloat16 q_val = static_cast(q_data[mulIdx]); + MLFloat16 q_rot_val = static_cast(q_rot_data[mulIdx]); + MLFloat16 k_val = static_cast(k_data[mulIdx]); + MLFloat16 k_rot_val = static_cast(k_rot_data[mulIdx]); + output1.push_back(static_cast(q_val * cos_val + q_rot_val * sin_val)); + output2.push_back(static_cast(k_val * cos_val + k_rot_val * sin_val)); + } + } + } + } +} + +static void 
RunTest() { + std::string op_type = "GemmaRotaryEmbedding"; + std::vector emb_dim = {1, 2, 2}; + std::vector mul_dim = {1, 3, 2, 2}; + std::vector> execution_providers; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + if (enable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + + // create rand inputs + RandomValueGenerator random{}; + const std::vector emb_data = random.Uniform(emb_dim, k_random_data_min, k_random_data_max); + const std::vector q = random.Uniform(mul_dim, k_random_data_min, k_random_data_max); + const std::vector q_rot = random.Uniform(mul_dim, k_random_data_min, k_random_data_max); + const std::vector k = random.Uniform(mul_dim, k_random_data_min, k_random_data_max); + const std::vector k_rot = random.Uniform(mul_dim, k_random_data_min, k_random_data_max); + + std::vector output1; + std::vector output2; + + calculateExpectedOutput(emb_data, q, q_rot, k, k_rot, mul_dim, output1, output2); + + test.AddInput("emb", emb_dim, emb_data); + test.AddInput("q_data", mul_dim, q); + test.AddInput("q_rot_data", mul_dim, q_rot); + test.AddInput("k_data", mul_dim, k); + test.AddInput("k_rot_data", mul_dim, k_rot); + test.AddOutput("output1", mul_dim, output1); + test.AddOutput("output2", mul_dim, output2); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(GemmaRotaryEmbeddingTest, GemmaRotaryEmbedding_Small) { + RunTest(); +} + +} // namespace test +} // namespace onnxruntime \ No newline at end of file