[ARM CPU] Add rotary embedding fp16 kernel (#23013)

### Description
Add fp16 kernel to rotary embedding to boost performance.


### Motivation and Context
Part of performance optimization work for group query attention
This commit is contained in:
Jing Fang 2024-12-06 21:25:48 +00:00 committed by GitHub
parent 401d16c671
commit bd5a759d0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 526 additions and 69 deletions

View file

@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/rotary_embedding.h
${MLAS_SRC_DIR}/rotary_embedding.cpp
)
target_sources(onnxruntime_mlas PRIVATE
@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set(mlas_platform_preprocess_srcs
@ -367,6 +372,8 @@ else()
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@ -384,8 +391,9 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@ -395,8 +403,9 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

View file

@ -4,6 +4,7 @@
#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
using onnxruntime::concurrency::ThreadPool;
@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
int cache_idx = 0;
bool sign = false;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
sign = i & 1;
j = sign ? i - 1 : i + 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i >= half_rotary_emb_dim);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
float input_data_j = static_cast<float>(input_data[j]);
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
if (sign) {
output_data_i += input_data_j * sin_data_cache_idx;
} else {
output_data_i -= input_data_j * sin_data_cache_idx;
}
output_data[i] = static_cast<T>(output_data_i);
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
if (rotary_emb_dim < head_size) {
std::memcpy(output_data + rotary_emb_dim,
input_data + rotary_emb_dim,
(head_size - rotary_emb_dim) * sizeof(T));
}
}
});

View file

@ -1435,6 +1435,29 @@ MLAS_FP16* Destination,
size_t Count
);
/**
* @brief rotary embedding for one hidden state vector
*
* @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported.
* @param input: input tensor, of shape [dim]
* @param sin: sin tensor, of shape [dim/2]
* @param cos: cos tensor, of shape [dim/2]
* @param dim: dimension of rotary embedding
* @param interleaved: whether the real part and imaginary parts are interleaved
* @param output: output tensor, of shape [dim]
*/
template <typename T>
void
MLASCALL
MlasRotaryEmbedOneRow(
const T* input,
const T* sin,
const T* cos,
size_t dim,
bool interleaved,
T* output
);
/**
* @brief Whether current CPU supports FP16 acceleration.
*/

View file

@ -6,7 +6,7 @@ Licensed under the MIT License.
Module Name:
fp16_neon_common.cpp
cast_kernel_neon.cpp
Abstract:

View file

@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
//
// Rotary embedding dispatch structure.
//
struct MLAS_ROPE_DISPATCH;
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
//
// Quantized depthwise convolution kernels.
//
@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM {
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
};
inline

View file

@ -543,6 +543,7 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
this->RopeDispatch = &MlasRopeDispatchNeon;
//
// Check if the processor supports ASIMD dot product instructions.

View file

@ -0,0 +1,101 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding.cpp
Abstract:
This module implements rotary embedding kernels for fp32/16.
--*/
#include "rotary_embedding.h"
namespace {
template <typename T>
void
MLASCALL
MlasRotaryEmbedOneRow_FallBack(
const T* input_data,
const T* sin_data,
const T* cos_data,
size_t rotary_emb_dim,
bool interleaved,
T* output_data
) {
const size_t half_rotary_emb_dim = rotary_emb_dim / 2;
size_t cache_idx = 0;
bool sign = false;
size_t j = 0;
for (size_t i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
sign = i & 1;
j = sign ? i - 1 : i + 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i >= half_rotary_emb_dim);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
float input_data_j = static_cast<float>(input_data[j]);
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
if (sign) {
output_data_i += input_data_j * sin_data_cache_idx;
} else {
output_data_i -= input_data_j * sin_data_cache_idx;
}
output_data[i] = static_cast<T>(output_data_i);
}
}
} // namespace
template <>
void
MLASCALL
MlasRotaryEmbedOneRow<float>(
const float* input,
const float* sin,
const float* cos,
size_t dim,
bool interleaved,
float* output
) {
const auto* dispatch = GetMlasPlatform().RopeDispatch;
if (dispatch == nullptr || dispatch->SRope == nullptr) {
MlasRotaryEmbedOneRow_FallBack<float>(input, sin, cos, dim, interleaved, output);
return;
}
dispatch->SRope(input, sin, cos, dim, interleaved, output);
}
template <>
void
MLASCALL
MlasRotaryEmbedOneRow<MLAS_FP16>(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
) {
const auto* dispatch = GetMlasPlatform().RopeDispatch;
if (dispatch == nullptr || dispatch->HRope == nullptr) {
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input, sin, cos, dim, interleaved, output);
return;
}
dispatch->HRope(input, sin, cos, dim, interleaved, output);
}

View file

@ -0,0 +1,46 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding.h
Abstract:
This module includes kernel function prototypes and helper functions for
implementing rotary embedding.
--*/
#pragma once
#include "mlasi.h"
struct MLAS_ROPE_DISPATCH {
// rotary embedding kernel for fp32
typedef void(SRope_Fn)(
const float* input,
const float* sin,
const float* cos,
size_t dim,
bool interleaved,
float* output
);
SRope_Fn* SRope = nullptr;
// rotary embedding kernel for fp16
typedef void(HRope_Fn)(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
);
HRope_Fn* HRope = nullptr;
};

View file

@ -0,0 +1,32 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding_kernel_neon.cpp
Abstract:
This module implements the rotary embedding kernels for ARM NEON.
--*/
#include "rotary_embedding.h"
#include "rotary_embedding_kernel_neon.h"
//
// Kernel dispatch structure definition.
//
const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
MLAS_ROPE_DISPATCH d;
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
if (MlasFp16AccelerationSupported()) {
d.HRope = rope_neon::RopeKernel_Fp16;
}
#endif
return d;
}();

View file

@ -0,0 +1,37 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding_kernel_neon.h
Abstract:
This module includes function declarations and common helper functions for
rotary embedding on ARM cpu.
--*/
#pragma once
#include <arm_neon.h>
#include "mlasi.h"
namespace rope_neon {
// Rotary embedding kernel for fp16. Embed one hidden state vector.
void
RopeKernel_Fp16(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
);
} // namespace rope_neon

View file

@ -0,0 +1,253 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
rotary_embedding_kernel_neon_fp16.cpp
Abstract:
This module implements the fp16 rotary embedding kernels for ARM NEON.
--*/
#include <arm_neon.h>
#include <cassert>
#include "fp16_common.h"
#include "rotary_embedding.h"
#include "rotary_embedding_kernel_neon.h"
namespace rope_neon {
namespace {
template <bool interleaved>
void
RopeKernel_Fp16_Impl(
const _mlas_fp16_* input,
const _mlas_fp16_* sin,
const _mlas_fp16_* cos,
size_t dim,
_mlas_fp16_* output
);
template <>
void
RopeKernel_Fp16_Impl<false>(
const _mlas_fp16_* input,
const _mlas_fp16_* sin,
const _mlas_fp16_* cos,
size_t dim,
_mlas_fp16_* output
) {
const size_t half_dim = dim >> 1;
size_t i = 0, j = half_dim;
for (; i + 7 < half_dim; i += 8, j += 8) {
float16x8_t real = MlasLoadFloat16x8(input + i);
float16x8_t imag = MlasLoadFloat16x8(input + j);
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
MlasStoreFloat16x8(output + i, real_out);
MlasStoreFloat16x8(output + j, imag_out);
}
for (; i + 3 < half_dim; i += 4, j += 4) {
float16x4_t real = MlasLoadFloat16x4(input + i);
float16x4_t imag = MlasLoadFloat16x4(input + j);
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreFloat16x4(output + i, real_out);
MlasStoreFloat16x4(output + j, imag_out);
}
if (half_dim - i == 3) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
real = MlasLoadLaneFloat16x4<2>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
MlasStoreLaneFloat16x4<2>(output + i + 2, real_out);
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out);
} else if (half_dim - i == 2) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
} else if (half_dim - i == 1) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
}
}
template <>
void
RopeKernel_Fp16_Impl<true>(
const _mlas_fp16_* input,
const _mlas_fp16_* sin,
const _mlas_fp16_* cos,
size_t dim,
_mlas_fp16_* output
) {
size_t i = 0;
for (; i + 15 < dim; i += 16) {
float16x8_t x0 = MlasLoadFloat16x8(input + i);
float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
float16x8_t real = vuzp1q_f16(x0, x1);
float16x8_t imag = vuzp2q_f16(x0, x1);
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
float16x8_t y0 = vzip1q_f16(real_out, imag_out);
float16x8_t y1 = vzip2q_f16(real_out, imag_out);
MlasStoreFloat16x8(output + i, y0);
MlasStoreFloat16x8(output + i + 8, y1);
}
for (; i + 7 < dim; i += 8) {
float16x4_t x0 = MlasLoadFloat16x4(input + i);
float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
float16x4_t real = vuzp1_f16(x0, x1);
float16x4_t imag = vuzp2_f16(x0, x1);
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
float16x4_t y0 = vzip1_f16(real_out, imag_out);
float16x4_t y1 = vzip2_f16(real_out, imag_out);
MlasStoreFloat16x4(output + i, y0);
MlasStoreFloat16x4(output + i + 4, y1);
}
if (dim - i == 6) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
MlasStoreLaneFloat16x4<2>(output + i + 4, real_out);
MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out);
} else if (dim - i == 4) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
} else if (dim - i == 2) {
float16x4_t real = MlasZeroFloat16x4();
float16x4_t imag = MlasZeroFloat16x4();
float16x4_t sin_val = MlasZeroFloat16x4();
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
}
}
} // namespace
void
RopeKernel_Fp16(
const MLAS_FP16* input,
const MLAS_FP16* sin,
const MLAS_FP16* cos,
size_t dim,
bool interleaved,
MLAS_FP16* output
) {
// real part and imaginary part must be paired
assert(dim % 2 == 0);
const auto* input_impl = reinterpret_cast<const _mlas_fp16_*>(input);
const auto* sin_impl = reinterpret_cast<const _mlas_fp16_*>(sin);
const auto* cos_impl = reinterpret_cast<const _mlas_fp16_*>(cos);
auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output);
if (interleaved) {
RopeKernel_Fp16_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
} else {
RopeKernel_Fp16_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
}
}
} // namespace rope_neon

View file

@ -75,7 +75,7 @@ static void RunTest(
if (enable_dml && !disable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
if (tensor_type == TensorType::kFloat && !disable_cpu) {
if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) {
execution_providers.push_back(DefaultCpuExecutionProvider());
}
if (enable_webgpu) {
@ -140,26 +140,7 @@ static void RunTests(const std::vector<float>& input_data,
int64_t interleaved = 0,
int64_t is_packed_batching = 0,
bool use_float16 = true) {
// FP32 test for CPU
RunTest(input_data,
position_ids,
cos_cache,
sin_cache,
output_data,
batch_size,
sequence_length,
head_size,
rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
is_packed_batching,
TensorType::kFloat,
false, /* disable_cpu */
true, /* disable_cuda */
true /* disable_dml */);
// FP32 test for CUDA and DML
// FP32 test for CPU, CUDA and DML
RunTest(input_data,
position_ids,
cos_cache,
@ -178,7 +159,7 @@ static void RunTests(const std::vector<float>& input_data,
false, /* disable_cuda */
false /* disable_dml */);
// FP16 test for CUDA and DML
// FP16 test for CPU, CUDA and DML
if (use_float16) {
RunTest(input_data,
position_ids,
@ -194,26 +175,9 @@ static void RunTests(const std::vector<float>& input_data,
interleaved,
is_packed_batching,
TensorType::kFloat16,
true, /* disable_cpu */
false, /* disable_cpu */
false, /* disable_cuda*/
false /* disable_dml */);
// RunTest(input_data,
// position_ids,
// cos_cache,
// sin_cache,
// output_data,
// batch_size,
// sequence_length,
// head_size,
// rotary_embedding_dim,
// num_heads,
// max_sequence_length,
// interleaved,
// TensorType::kBFloat16,
// true, /* disable_cpu */
// false, /* disable_cuda*/
// false /* disable_dml */);
}
}