From bd5a759d0cdbed6e7f611c990d4eb5457a9ecf60 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Fri, 6 Dec 2024 21:25:48 +0000 Subject: [PATCH] [ARM CPU] Add rotary embedding fp16 kernel (#23013) ### Description Add fp16 kernel to rotary embedding to boost performance. ### Motivation and Context Part of performance optimization work for group query attention --- cmake/onnxruntime_mlas.cmake | 15 +- .../contrib_ops/cpu/bert/rotary_embedding.cc | 32 +-- onnxruntime/core/mlas/inc/mlas.h | 23 ++ ...6_neon_common.cpp => cast_kernel_neon.cpp} | 2 +- onnxruntime/core/mlas/lib/mlasi.h | 9 + onnxruntime/core/mlas/lib/platform.cpp | 1 + .../core/mlas/lib/rotary_embedding.cpp | 101 +++++++ onnxruntime/core/mlas/lib/rotary_embedding.h | 46 ++++ .../mlas/lib/rotary_embedding_kernel_neon.cpp | 32 +++ .../mlas/lib/rotary_embedding_kernel_neon.h | 37 +++ .../lib/rotary_embedding_kernel_neon_fp16.cpp | 253 ++++++++++++++++++ .../contrib_ops/rotary_embedding_op_test.cc | 44 +-- ...ch_fp16_neon_common.cpp => bench_cast.cpp} | 0 13 files changed, 526 insertions(+), 69 deletions(-) rename onnxruntime/core/mlas/lib/{fp16_neon_common.cpp => cast_kernel_neon.cpp} (99%) create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding.cpp create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding.h create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h create mode 100644 onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp rename onnxruntime/test/mlas/bench/{bench_fp16_neon_common.cpp => bench_cast.cpp} (100%) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 10c307b3b9..5124262ec0 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp 
${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/rotary_embedding.h + ${MLAS_SRC_DIR}/rotary_embedding.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -367,6 +372,8 @@ else() ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -384,8 +391,9 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -395,8 +403,9 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 
") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index cbfd2f0949..9a6c2af022 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/bert/rotary_embedding.h" #include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" using onnxruntime::concurrency::ThreadPool; @@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; - int cache_idx = 0; - bool sign = false; - int j = 0; - for (int i = 0; i < rotary_emb_dim; i++) { - if (interleaved) { - cache_idx = (i / 2) % half_rotary_emb_dim; - sign = i & 1; - j = sign ? 
i - 1 : i + 1; // i - sign - } else { - cache_idx = i % half_rotary_emb_dim; - sign = (i >= half_rotary_emb_dim); - j = (i + half_rotary_emb_dim) % rotary_emb_dim; - } - float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]); - float input_data_j = static_cast<float>(input_data[j]); - float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]); - if (sign) { - output_data_i += input_data_j * sin_data_cache_idx; - } else { - output_data_i -= input_data_j * sin_data_cache_idx; - } - output_data[i] = static_cast<T>(output_data_i); - } - for (int i = rotary_emb_dim; i < head_size; i++) { - output_data[i] = input_data[i]; + MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data); + + if (rotary_emb_dim < head_size) { + std::memcpy(output_data + rotary_emb_dim, + input_data + rotary_emb_dim, + (head_size - rotary_emb_dim) * sizeof(T)); } } }); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 28ae64c4d5..207c058d89 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1435,6 +1435,29 @@ MLAS_FP16* Destination, size_t Count ); +/** + * @brief rotary embedding for one hidden state vector + * + * @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported. + * @param input: input tensor, of shape [dim] + * @param sin: sin tensor, of shape [dim/2] + * @param cos: cos tensor, of shape [dim/2] + * @param dim: dimension of rotary embedding, must be even + * @param interleaved: whether the real part and imaginary parts are interleaved + * @param output: output tensor, of shape [dim] + */ +template <typename T> +void +MLASCALL +MlasRotaryEmbedOneRow( + const T* input, + const T* sin, + const T* cos, + size_t dim, + bool interleaved, + T* output +); + /** * @brief Whether current CPU supports FP16 acceleration. 
*/ diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp similarity index 99% rename from onnxruntime/core/mlas/lib/fp16_neon_common.cpp rename to onnxruntime/core/mlas/lib/cast_kernel_neon.cpp index 29734c2277..8a385c9c61 100644 --- a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp +++ b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - fp16_neon_common.cpp + cast_kernel_neon.cpp Abstract: diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0533a5e49b..100d7d4775 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +// +// Rotary embedding dispatch structure. +// +struct MLAS_ROPE_DISPATCH; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; + + // // Quantized depthwise convolution kernels. // @@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM { MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + + const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index b3c9461293..ec572a4150 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -543,6 +543,7 @@ Return Value: this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + this->RopeDispatch = &MlasRopeDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. 
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp
new file mode 100644
index 0000000000..1f8f7b2406
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp
@@ -0,0 +1,121 @@
+/*++
+
+Copyright (c) Intel Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding.cpp
+
+Abstract:
+
+    This module implements rotary embedding kernels for fp32/16.
+
+--*/
+
+#include "rotary_embedding.h"
+
+namespace {
+
+//
+// Portable scalar rotary embedding for one row, used when the platform
+// dispatch table supplies no kernel for T. The arithmetic is done in float
+// and each result is converted back to T.
+//
+// Element i is combined with its partner j through the shared sin/cos entry
+// cache_idx:
+//   interleaved:     pairs are (2k, 2k + 1),                cache_idx = k
+//   non-interleaved: pairs are (k, k + rotary_emb_dim / 2), cache_idx = k
+// The first element of a pair gets x_i * cos - x_j * sin, the second gets
+// x_i * cos + x_j * sin.
+//
+template <typename T>
+void
+MLASCALL
+MlasRotaryEmbedOneRow_FallBack(
+    const T* input_data,
+    const T* sin_data,
+    const T* cos_data,
+    size_t rotary_emb_dim,
+    bool interleaved,
+    T* output_data
+) {
+    const size_t half_rotary_emb_dim = rotary_emb_dim / 2;
+    size_t cache_idx = 0;
+    bool sign = false;
+    size_t j = 0;
+    for (size_t i = 0; i < rotary_emb_dim; i++) {
+        if (interleaved) {
+            cache_idx = (i / 2) % half_rotary_emb_dim;
+            sign = i & 1;
+            j = sign ? i - 1 : i + 1;  // i - sign
+        } else {
+            cache_idx = i % half_rotary_emb_dim;
+            sign = (i >= half_rotary_emb_dim);
+            j = (i + half_rotary_emb_dim) % rotary_emb_dim;
+        }
+        float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
+        float input_data_j = static_cast<float>(input_data[j]);
+        float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
+        if (sign) {
+            output_data_i += input_data_j * sin_data_cache_idx;
+        } else {
+            output_data_i -= input_data_j * sin_data_cache_idx;
+        }
+        output_data[i] = static_cast<T>(output_data_i);
+    }
+}
+
+}  // namespace
+
+
+//
+// fp32 entry point: run the platform's SRope kernel when one is registered,
+// otherwise the scalar fallback above.
+//
+template <>
+void
+MLASCALL
+MlasRotaryEmbedOneRow<float>(
+    const float* input,
+    const float* sin,
+    const float* cos,
+    size_t dim,
+    bool interleaved,
+    float* output
+) {
+    const auto* dispatch = GetMlasPlatform().RopeDispatch;
+
+    if (dispatch == nullptr || dispatch->SRope == nullptr) {
+        MlasRotaryEmbedOneRow_FallBack<float>(input, sin, cos, dim, interleaved, output);
+        return;
+    }
+
+    dispatch->SRope(input, sin, cos, dim, interleaved, output);
+}
+
+//
+// fp16 entry point: run the platform's HRope kernel when one is registered,
+// otherwise the scalar fallback above.
+//
+template <>
+void
+MLASCALL
+MlasRotaryEmbedOneRow<MLAS_FP16>(
+    const MLAS_FP16* input,
+    const MLAS_FP16* sin,
+    const MLAS_FP16* cos,
+    size_t dim,
+    bool interleaved,
+    MLAS_FP16* output
+) {
+    const auto* dispatch = GetMlasPlatform().RopeDispatch;
+
+    if (dispatch == nullptr || dispatch->HRope == nullptr) {
+        MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input, sin, cos, dim, interleaved, output);
+        return;
+    }
+
+    dispatch->HRope(input, sin, cos, dim, interleaved, output);
+}
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h
new file mode 100644
index 0000000000..352dddccf1
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding.h
@@ -0,0 +1,51 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding.h
+
+Abstract:
+
+    This module includes kernel function prototypes and helper functions for
+    implementing rotary embedding.
+
+--*/
+
+#pragma once
+
+#include "mlasi.h"
+
+//
+// Per-platform table of rotary embedding kernels. Each kernel embeds one row
+// and takes the same arguments as MlasRotaryEmbedOneRow. A null entry means
+// the platform has no kernel for that type and the portable fallback is used.
+//
+struct MLAS_ROPE_DISPATCH {
+    // rotary embedding kernel for fp32
+    typedef void(SRope_Fn)(
+        const float* input,
+        const float* sin,
+        const float* cos,
+        size_t dim,
+        bool interleaved,
+        float* output
+    );
+
+    SRope_Fn* SRope = nullptr;
+
+    // rotary embedding kernel for fp16
+    typedef void(HRope_Fn)(
+        const MLAS_FP16* input,
+        const MLAS_FP16* sin,
+        const MLAS_FP16* cos,
+        size_t dim,
+        bool interleaved,
+        MLAS_FP16* output
+    );
+
+    HRope_Fn* HRope = nullptr;
+};
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp
new file mode 100644
index 0000000000..e59a95cd9e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp
@@ -0,0 +1,32 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the rotary embedding kernels for ARM NEON.
+
+--*/
+
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_neon.h"
+
+// Kernel dispatch structure definition. HRope is registered only when fp16
+// vector intrinsics are compiled in and the CPU reports fp16 acceleration;
+// SRope is left null here.
+const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
+    MLAS_ROPE_DISPATCH d;
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+    if (MlasFp16AccelerationSupported()) {
+        d.HRope = rope_neon::RopeKernel_Fp16;
+    }
+#endif
+    return d;
+}();
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h
new file mode 100644
index 0000000000..8153f65650
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h
@@ -0,0 +1,39 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding_kernel_neon.h
+
+Abstract:
+
+    This module includes function declarations and common helper functions for
+    rotary embedding on ARM cpu.
+
+--*/
+
+#pragma once
+
+#include <arm_neon.h>
+
+#include "mlasi.h"
+
+namespace rope_neon {
+
+// Rotary embedding kernel for fp16. Embed one hidden state vector.
+// `input` and `output` hold `dim` values, `sin` and `cos` hold dim/2 values;
+// `dim` must be even.
+void
+RopeKernel_Fp16(
+    const MLAS_FP16* input,
+    const MLAS_FP16* sin,
+    const MLAS_FP16* cos,
+    size_t dim,
+    bool interleaved,
+    MLAS_FP16* output
+);
+
+}  // namespace rope_neon
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
new file mode 100644
index 0000000000..3e2eb8fee0
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
@@ -0,0 +1,268 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding_kernel_neon_fp16.cpp
+
+Abstract:
+
+    This module implements the fp16 rotary embedding kernels for ARM NEON.
+
+--*/
+
+#include <arm_neon.h>
+#include <cassert>
+
+#include "fp16_common.h"
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_neon.h"
+
+namespace rope_neon {
+
+namespace {
+
+// Rotates one row of fp16 values. Each (real, imag) pair k shares sin[k] and
+// cos[k], so `sin` and `cos` hold dim/2 values. Both layouts compute
+//     real_out = real * cos - imag * sin
+//     imag_out = real * sin + imag * cos
+// and differ only in where the two halves of a pair live in the row.
+template <bool interleaved>
+void
+RopeKernel_Fp16_Impl(
+    const _mlas_fp16_* input,
+    const _mlas_fp16_* sin,
+    const _mlas_fp16_* cos,
+    size_t dim,
+    _mlas_fp16_* output
+);
+
+// Non-interleaved layout: pair k is (input[k], input[k + dim/2]).
+template <>
+void
+RopeKernel_Fp16_Impl<false>(
+    const _mlas_fp16_* input,
+    const _mlas_fp16_* sin,
+    const _mlas_fp16_* cos,
+    size_t dim,
+    _mlas_fp16_* output
+) {
+    const size_t half_dim = dim >> 1;
+    size_t i = 0, j = half_dim;
+    for (; i + 7 < half_dim; i += 8, j += 8) {
+        float16x8_t real = MlasLoadFloat16x8(input + i);
+        float16x8_t imag = MlasLoadFloat16x8(input + j);
+        float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
+        float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
+        float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
+        float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
+        MlasStoreFloat16x8(output + i, real_out);
+        MlasStoreFloat16x8(output + j, imag_out);
+    }
+    for (; i + 3 < half_dim; i += 4, j += 4) {
+        float16x4_t real = MlasLoadFloat16x4(input + i);
+        float16x4_t imag = MlasLoadFloat16x4(input + j);
+        float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
+        float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreFloat16x4(output + i, real_out);
+        MlasStoreFloat16x4(output + j, imag_out);
+    }
+    // Fewer than four pairs remain: load and store them one lane at a time.
+    if (half_dim - i == 3) {
+        float16x4_t real = MlasZeroFloat16x4();
+        float16x4_t imag = MlasZeroFloat16x4();
+        float16x4_t sin_val = MlasZeroFloat16x4();
+        float16x4_t cos_val = MlasZeroFloat16x4();
+        real = MlasLoadLaneFloat16x4<0>(input + i, real);
+        real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
+        real = MlasLoadLaneFloat16x4<2>(input + i + 2, real);
+        imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
+        imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
+        imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreLaneFloat16x4<0>(output + i, real_out);
+        MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
+        MlasStoreLaneFloat16x4<2>(output + i + 2, real_out);
+        MlasStoreLaneFloat16x4<0>(output + j, imag_out);
+        MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
+        MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out);
+    } else if (half_dim - i == 2) {
+        float16x4_t real = MlasZeroFloat16x4();
+        float16x4_t imag = MlasZeroFloat16x4();
+        float16x4_t sin_val = MlasZeroFloat16x4();
+        float16x4_t cos_val = MlasZeroFloat16x4();
+        real = MlasLoadLaneFloat16x4<0>(input + i, real);
+        real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
+        imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
+        imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreLaneFloat16x4<0>(output + i, real_out);
+        MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
+        MlasStoreLaneFloat16x4<0>(output + j, imag_out);
+        MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
+    } else if (half_dim - i == 1) {
+        float16x4_t real = MlasZeroFloat16x4();
+        float16x4_t imag = MlasZeroFloat16x4();
+        float16x4_t sin_val = MlasZeroFloat16x4();
+        float16x4_t cos_val = MlasZeroFloat16x4();
+        real = MlasLoadLaneFloat16x4<0>(input + i, real);
+        imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreLaneFloat16x4<0>(output + i, real_out);
+        MlasStoreLaneFloat16x4<0>(output + j, imag_out);
+    }
+}
+
+// Interleaved layout: pair k is (input[2k], input[2k + 1]). `i` walks input
+// elements, so the sin/cos entry for the pair starting at `i` is i / 2.
+template <>
+void
+RopeKernel_Fp16_Impl<true>(
+    const _mlas_fp16_* input,
+    const _mlas_fp16_* sin,
+    const _mlas_fp16_* cos,
+    size_t dim,
+    _mlas_fp16_* output
+) {
+    size_t i = 0;
+    for (; i + 15 < dim; i += 16) {
+        float16x8_t x0 = MlasLoadFloat16x8(input + i);
+        float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
+        float16x8_t real = vuzp1q_f16(x0, x1);
+        float16x8_t imag = vuzp2q_f16(x0, x1);
+        const size_t k = i >> 1;
+        float16x8_t sin_val = MlasLoadFloat16x8(sin + k);
+        float16x8_t cos_val = MlasLoadFloat16x8(cos + k);
+        float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
+        float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
+        float16x8_t y0 = vzip1q_f16(real_out, imag_out);
+        float16x8_t y1 = vzip2q_f16(real_out, imag_out);
+        MlasStoreFloat16x8(output + i, y0);
+        MlasStoreFloat16x8(output + i + 8, y1);
+    }
+    for (; i + 7 < dim; i += 8) {
+        float16x4_t x0 = MlasLoadFloat16x4(input + i);
+        float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
+        float16x4_t real = vuzp1_f16(x0, x1);
+        float16x4_t imag = vuzp2_f16(x0, x1);
+        const size_t k = i >> 1;
+        float16x4_t sin_val = MlasLoadFloat16x4(sin + k);
+        float16x4_t cos_val = MlasLoadFloat16x4(cos + k);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        float16x4_t y0 = vzip1_f16(real_out, imag_out);
+        float16x4_t y1 = vzip2_f16(real_out, imag_out);
+        MlasStoreFloat16x4(output + i, y0);
+        MlasStoreFloat16x4(output + i + 4, y1);
+    }
+    // Fewer than four pairs remain: load and store them one lane at a time.
+    const size_t k = i >> 1;
+    if (dim - i == 6) {
+        float16x4_t real = MlasZeroFloat16x4();
+        float16x4_t imag = MlasZeroFloat16x4();
+        float16x4_t sin_val = MlasZeroFloat16x4();
+        float16x4_t cos_val = MlasZeroFloat16x4();
+        real = MlasLoadLaneFloat16x4<0>(input + i, real);
+        imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
+        real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
+        imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
+        real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
+        imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + k, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<1>(sin + k + 1, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<2>(sin + k + 2, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + k, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<1>(cos + k + 1, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<2>(cos + k + 2, cos_val);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreLaneFloat16x4<0>(output + i, real_out);
+        MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
+        MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
+        MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
+        MlasStoreLaneFloat16x4<2>(output + i + 4, real_out);
+        MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out);
+    } else if (dim - i == 4) {
+        float16x4_t real = MlasZeroFloat16x4();
+        float16x4_t imag = MlasZeroFloat16x4();
+        float16x4_t sin_val = MlasZeroFloat16x4();
+        float16x4_t cos_val = MlasZeroFloat16x4();
+        real = MlasLoadLaneFloat16x4<0>(input + i, real);
+        imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
+        real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
+        imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + k, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<1>(sin + k + 1, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + k, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<1>(cos + k + 1, cos_val);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreLaneFloat16x4<0>(output + i, real_out);
+        MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
+        MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
+        MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
+    } else if (dim - i == 2) {
+        float16x4_t real = MlasZeroFloat16x4();
+        float16x4_t imag = MlasZeroFloat16x4();
+        float16x4_t sin_val = MlasZeroFloat16x4();
+        float16x4_t cos_val = MlasZeroFloat16x4();
+        real = MlasLoadLaneFloat16x4<0>(input + i, real);
+        imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + k, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + k, cos_val);
+        float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+        float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+        MlasStoreLaneFloat16x4<0>(output + i, real_out);
+        MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
+    }
+}
+
+}  // namespace
+
+// Dispatches one row to the layout-specific kernel. MLAS_FP16 buffers are
+// reinterpreted as the raw _mlas_fp16_ element type the NEON helpers take.
+void
+RopeKernel_Fp16(
+    const MLAS_FP16* input,
+    const MLAS_FP16* sin,
+    const MLAS_FP16* cos,
+    size_t dim,
+    bool interleaved,
+    MLAS_FP16* output
+) {
+    // real part and imaginary part must be paired
+    assert(dim % 2 == 0);
+
+    const auto* input_impl = reinterpret_cast<const _mlas_fp16_*>(input);
+    const auto* sin_impl = reinterpret_cast<const _mlas_fp16_*>(sin);
+    const auto* cos_impl = reinterpret_cast<const _mlas_fp16_*>(cos);
+    auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output);
+
+    if (interleaved) {
+        RopeKernel_Fp16_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
+    } else {
+        RopeKernel_Fp16_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
+    }
+}
+
+}  // namespace rope_neon
diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 7d5a701487..0e964cf64f 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -75,7 +75,7 @@ static void RunTest( if (enable_dml && !disable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (tensor_type == TensorType::kFloat && !disable_cpu) { + if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } if (enable_webgpu) { @@ -140,26 +140,7 @@ static void RunTests(const std::vector& input_data, int64_t interleaved = 0, int64_t is_packed_batching = 0, bool use_float16 = true) { - // FP32 test for CPU - RunTest(input_data, - position_ids, - cos_cache, - sin_cache, - output_data, - batch_size, - sequence_length, - head_size, - rotary_embedding_dim, - num_heads, - max_sequence_length, - interleaved, - is_packed_batching, - TensorType::kFloat, - false, /* disable_cpu */ - true, /* disable_cuda */ - true /* disable_dml */); - - // FP32 test for CUDA and DML + // FP32 test for CPU, CUDA and DML RunTest(input_data, position_ids, cos_cache, @@ -178,7 +159,7 @@ static void RunTests(const std::vector& input_data, false, /* disable_cuda */ false /* disable_dml */); - // FP16 test for CUDA and DML + // FP16 test for CPU, CUDA and DML if (use_float16) { RunTest(input_data, position_ids, @@ -194,26 +175,9 @@ static void RunTests(const std::vector& input_data, interleaved, is_packed_batching, TensorType::kFloat16, - true, /* disable_cpu */ + false, /* disable_cpu */ false, /* disable_cuda*/ false /* disable_dml */); - - // RunTest(input_data, - // position_ids, - // cos_cache, - // sin_cache, - // output_data, - // batch_size, - // sequence_length, - // head_size, - // rotary_embedding_dim, - // num_heads, - // max_sequence_length, - // interleaved, - // TensorType::kBFloat16, - // true, /* disable_cpu */ - 
// false, /* disable_cuda*/ - // false /* disable_dml */); } } diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_cast.cpp similarity index 100% rename from onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp rename to onnxruntime/test/mlas/bench/bench_cast.cpp