mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
[ARM CPU] Add rotary embedding fp16 kernel (#23013)
### Description Add fp16 kernel to rotary embedding to boost performance. ### Motivation and Context Part of performance optimization work for group query attention
This commit is contained in:
parent
401d16c671
commit
bd5a759d0c
13 changed files with 526 additions and 69 deletions
|
|
@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
|
|||
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
|
||||
${MLAS_SRC_DIR}/flashattn.cpp
|
||||
${MLAS_SRC_DIR}/cast.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding.h
|
||||
${MLAS_SRC_DIR}/rotary_embedding.cpp
|
||||
)
|
||||
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
|
|
@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
|
|||
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
${MLAS_SRC_DIR}/fp16_neon_common.cpp
|
||||
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
|
||||
)
|
||||
|
||||
set(mlas_platform_preprocess_srcs
|
||||
|
|
@ -367,6 +372,8 @@ else()
|
|||
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
|
||||
|
|
@ -384,8 +391,9 @@ else()
|
|||
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
|
||||
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/fp16_neon_common.cpp
|
||||
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
|
||||
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
|
||||
|
|
@ -395,8 +403,9 @@ else()
|
|||
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
|
||||
endif()
|
||||
|
||||
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "contrib_ops/cpu/bert/rotary_embedding.h"
|
||||
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
|
||||
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
using onnxruntime::concurrency::ThreadPool;
|
||||
|
|
@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
|
|||
const T* cos_data = cos_cache + cache_offset;
|
||||
const T* sin_data = sin_cache + cache_offset;
|
||||
|
||||
int cache_idx = 0;
|
||||
bool sign = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < rotary_emb_dim; i++) {
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_emb_dim;
|
||||
sign = i & 1;
|
||||
j = sign ? i - 1 : i + 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_emb_dim;
|
||||
sign = (i >= half_rotary_emb_dim);
|
||||
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
|
||||
}
|
||||
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
|
||||
float input_data_j = static_cast<float>(input_data[j]);
|
||||
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
|
||||
if (sign) {
|
||||
output_data_i += input_data_j * sin_data_cache_idx;
|
||||
} else {
|
||||
output_data_i -= input_data_j * sin_data_cache_idx;
|
||||
}
|
||||
output_data[i] = static_cast<T>(output_data_i);
|
||||
}
|
||||
for (int i = rotary_emb_dim; i < head_size; i++) {
|
||||
output_data[i] = input_data[i];
|
||||
MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
|
||||
|
||||
if (rotary_emb_dim < head_size) {
|
||||
std::memcpy(output_data + rotary_emb_dim,
|
||||
input_data + rotary_emb_dim,
|
||||
(head_size - rotary_emb_dim) * sizeof(T));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1435,6 +1435,29 @@ MLAS_FP16* Destination,
|
|||
size_t Count
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief rotary embedding for one hidden state vector
|
||||
*
|
||||
* @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported.
|
||||
* @param input: input tensor, of shape [dim]
|
||||
* @param sin: sin tensor, of shape [dim/2]
|
||||
* @param cos: cos tensor, of shape [dim/2]
|
||||
* @param dim: dimension of rotary embedding
|
||||
* @param interleaved: whether the real part and imaginary parts are interleaved
|
||||
* @param output: output tensor, of shape [dim]
|
||||
*/
|
||||
template <typename T>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow(
|
||||
const T* input,
|
||||
const T* sin,
|
||||
const T* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
T* output
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Whether current CPU supports FP16 acceleration.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Licensed under the MIT License.
|
|||
|
||||
Module Name:
|
||||
|
||||
fp16_neon_common.cpp
|
||||
cast_kernel_neon.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
|
|
@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
|
|||
|
||||
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
|
||||
|
||||
//
|
||||
// Rotary embedding dispatch structure.
|
||||
//
|
||||
struct MLAS_ROPE_DISPATCH;
|
||||
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
|
||||
|
||||
|
||||
//
|
||||
// Quantized depthwise convolution kernels.
|
||||
//
|
||||
|
|
@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM {
|
|||
|
||||
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
|
||||
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
|
||||
|
||||
const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
|
||||
};
|
||||
|
||||
inline
|
||||
|
|
|
|||
|
|
@ -543,6 +543,7 @@ Return Value:
|
|||
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
|
||||
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
|
||||
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
|
||||
this->RopeDispatch = &MlasRopeDispatchNeon;
|
||||
|
||||
//
|
||||
// Check if the processor supports ASIMD dot product instructions.
|
||||
|
|
|
|||
101
onnxruntime/core/mlas/lib/rotary_embedding.cpp
Normal file
101
onnxruntime/core/mlas/lib/rotary_embedding.cpp
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Intel Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements rotary embedding kernels for fp32/16.
|
||||
|
||||
--*/
|
||||
|
||||
#include "rotary_embedding.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow_FallBack(
|
||||
const T* input_data,
|
||||
const T* sin_data,
|
||||
const T* cos_data,
|
||||
size_t rotary_emb_dim,
|
||||
bool interleaved,
|
||||
T* output_data
|
||||
) {
|
||||
const size_t half_rotary_emb_dim = rotary_emb_dim / 2;
|
||||
size_t cache_idx = 0;
|
||||
bool sign = false;
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < rotary_emb_dim; i++) {
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_emb_dim;
|
||||
sign = i & 1;
|
||||
j = sign ? i - 1 : i + 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_emb_dim;
|
||||
sign = (i >= half_rotary_emb_dim);
|
||||
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
|
||||
}
|
||||
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
|
||||
float input_data_j = static_cast<float>(input_data[j]);
|
||||
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
|
||||
if (sign) {
|
||||
output_data_i += input_data_j * sin_data_cache_idx;
|
||||
} else {
|
||||
output_data_i -= input_data_j * sin_data_cache_idx;
|
||||
}
|
||||
output_data[i] = static_cast<T>(output_data_i);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
template <>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow<float>(
|
||||
const float* input,
|
||||
const float* sin,
|
||||
const float* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
float* output
|
||||
) {
|
||||
const auto* dispatch = GetMlasPlatform().RopeDispatch;
|
||||
|
||||
if (dispatch == nullptr || dispatch->SRope == nullptr) {
|
||||
MlasRotaryEmbedOneRow_FallBack<float>(input, sin, cos, dim, interleaved, output);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch->SRope(input, sin, cos, dim, interleaved, output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void
|
||||
MLASCALL
|
||||
MlasRotaryEmbedOneRow<MLAS_FP16>(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
) {
|
||||
const auto* dispatch = GetMlasPlatform().RopeDispatch;
|
||||
|
||||
if (dispatch == nullptr || dispatch->HRope == nullptr) {
|
||||
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input, sin, cos, dim, interleaved, output);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch->HRope(input, sin, cos, dim, interleaved, output);
|
||||
}
|
||||
46
onnxruntime/core/mlas/lib/rotary_embedding.h
Normal file
46
onnxruntime/core/mlas/lib/rotary_embedding.h
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module includes kernel function prototypes and helper functions for
|
||||
implementing rotary embedding.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
struct MLAS_ROPE_DISPATCH {
|
||||
// rotary embedding kernel for fp32
|
||||
typedef void(SRope_Fn)(
|
||||
const float* input,
|
||||
const float* sin,
|
||||
const float* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
float* output
|
||||
);
|
||||
|
||||
SRope_Fn* SRope = nullptr;
|
||||
|
||||
// rotary embedding kernel for fp16
|
||||
typedef void(HRope_Fn)(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
);
|
||||
|
||||
HRope_Fn* HRope = nullptr;
|
||||
};
|
||||
32
onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp
Normal file
32
onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding_kernel_neon.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the rotary embedding kernels for ARM NEON.
|
||||
|
||||
--*/
|
||||
|
||||
#include "rotary_embedding.h"
|
||||
#include "rotary_embedding_kernel_neon.h"
|
||||
|
||||
//
|
||||
// Kernel dispatch structure definition.
|
||||
//
|
||||
const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
|
||||
MLAS_ROPE_DISPATCH d;
|
||||
|
||||
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
|
||||
if (MlasFp16AccelerationSupported()) {
|
||||
d.HRope = rope_neon::RopeKernel_Fp16;
|
||||
}
|
||||
#endif
|
||||
return d;
|
||||
}();
|
||||
37
onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h
Normal file
37
onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding_kernel_neon.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module includes function declarations and common helper functions for
|
||||
rotary embedding on ARM cpu.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
namespace rope_neon {
|
||||
|
||||
// Rotary embedding kernel for fp16. Embed one hidden state vector.
|
||||
void
|
||||
RopeKernel_Fp16(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
);
|
||||
|
||||
} // namespace rope_neon
|
||||
253
onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
Normal file
253
onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
rotary_embedding_kernel_neon_fp16.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the fp16 rotary embedding kernels for ARM NEON.
|
||||
|
||||
--*/
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <cassert>
|
||||
|
||||
#include "fp16_common.h"
|
||||
#include "rotary_embedding.h"
|
||||
#include "rotary_embedding_kernel_neon.h"
|
||||
|
||||
namespace rope_neon {
|
||||
|
||||
namespace {
|
||||
|
||||
template <bool interleaved>
|
||||
void
|
||||
RopeKernel_Fp16_Impl(
|
||||
const _mlas_fp16_* input,
|
||||
const _mlas_fp16_* sin,
|
||||
const _mlas_fp16_* cos,
|
||||
size_t dim,
|
||||
_mlas_fp16_* output
|
||||
);
|
||||
|
||||
template <>
|
||||
void
|
||||
RopeKernel_Fp16_Impl<false>(
|
||||
const _mlas_fp16_* input,
|
||||
const _mlas_fp16_* sin,
|
||||
const _mlas_fp16_* cos,
|
||||
size_t dim,
|
||||
_mlas_fp16_* output
|
||||
) {
|
||||
const size_t half_dim = dim >> 1;
|
||||
size_t i = 0, j = half_dim;
|
||||
for (; i + 7 < half_dim; i += 8, j += 8) {
|
||||
float16x8_t real = MlasLoadFloat16x8(input + i);
|
||||
float16x8_t imag = MlasLoadFloat16x8(input + j);
|
||||
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
|
||||
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
|
||||
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
|
||||
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreFloat16x8(output + i, real_out);
|
||||
MlasStoreFloat16x8(output + j, imag_out);
|
||||
}
|
||||
for (; i + 3 < half_dim; i += 4, j += 4) {
|
||||
float16x4_t real = MlasLoadFloat16x4(input + i);
|
||||
float16x4_t imag = MlasLoadFloat16x4(input + j);
|
||||
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
|
||||
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreFloat16x4(output + i, real_out);
|
||||
MlasStoreFloat16x4(output + j, imag_out);
|
||||
}
|
||||
if (half_dim - i == 3) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
|
||||
real = MlasLoadLaneFloat16x4<2>(input + i + 2, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
|
||||
imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + i + 2, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out);
|
||||
} else if (half_dim - i == 2) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
|
||||
} else if (half_dim - i == 1) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + j, imag_out);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void
|
||||
RopeKernel_Fp16_Impl<true>(
|
||||
const _mlas_fp16_* input,
|
||||
const _mlas_fp16_* sin,
|
||||
const _mlas_fp16_* cos,
|
||||
size_t dim,
|
||||
_mlas_fp16_* output
|
||||
) {
|
||||
size_t i = 0;
|
||||
for (; i + 15 < dim; i += 16) {
|
||||
float16x8_t x0 = MlasLoadFloat16x8(input + i);
|
||||
float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
|
||||
float16x8_t real = vuzp1q_f16(x0, x1);
|
||||
float16x8_t imag = vuzp2q_f16(x0, x1);
|
||||
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
|
||||
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
|
||||
float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
|
||||
float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
|
||||
float16x8_t y0 = vzip1q_f16(real_out, imag_out);
|
||||
float16x8_t y1 = vzip2q_f16(real_out, imag_out);
|
||||
MlasStoreFloat16x8(output + i, y0);
|
||||
MlasStoreFloat16x8(output + i + 8, y1);
|
||||
}
|
||||
for (; i + 7 < dim; i += 8) {
|
||||
float16x4_t x0 = MlasLoadFloat16x4(input + i);
|
||||
float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
|
||||
float16x4_t real = vuzp1_f16(x0, x1);
|
||||
float16x4_t imag = vuzp2_f16(x0, x1);
|
||||
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
|
||||
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
float16x4_t y0 = vzip1_f16(real_out, imag_out);
|
||||
float16x4_t y1 = vzip2_f16(real_out, imag_out);
|
||||
MlasStoreFloat16x4(output + i, y0);
|
||||
MlasStoreFloat16x4(output + i + 4, y1);
|
||||
}
|
||||
if (dim - i == 6) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
|
||||
real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
|
||||
imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + i + 4, real_out);
|
||||
MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out);
|
||||
} else if (dim - i == 4) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
|
||||
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
|
||||
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
|
||||
MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
|
||||
} else if (dim - i == 2) {
|
||||
float16x4_t real = MlasZeroFloat16x4();
|
||||
float16x4_t imag = MlasZeroFloat16x4();
|
||||
float16x4_t sin_val = MlasZeroFloat16x4();
|
||||
float16x4_t cos_val = MlasZeroFloat16x4();
|
||||
real = MlasLoadLaneFloat16x4<0>(input + i, real);
|
||||
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
|
||||
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
|
||||
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
|
||||
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
|
||||
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
|
||||
MlasStoreLaneFloat16x4<0>(output + i, real_out);
|
||||
MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void
|
||||
RopeKernel_Fp16(
|
||||
const MLAS_FP16* input,
|
||||
const MLAS_FP16* sin,
|
||||
const MLAS_FP16* cos,
|
||||
size_t dim,
|
||||
bool interleaved,
|
||||
MLAS_FP16* output
|
||||
) {
|
||||
// real part and imaginary part must be paired
|
||||
assert(dim % 2 == 0);
|
||||
|
||||
const auto* input_impl = reinterpret_cast<const _mlas_fp16_*>(input);
|
||||
const auto* sin_impl = reinterpret_cast<const _mlas_fp16_*>(sin);
|
||||
const auto* cos_impl = reinterpret_cast<const _mlas_fp16_*>(cos);
|
||||
auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output);
|
||||
|
||||
if (interleaved) {
|
||||
RopeKernel_Fp16_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
|
||||
} else {
|
||||
RopeKernel_Fp16_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rope_neon
|
||||
|
|
@ -75,7 +75,7 @@ static void RunTest(
|
|||
if (enable_dml && !disable_dml) {
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
}
|
||||
if (tensor_type == TensorType::kFloat && !disable_cpu) {
|
||||
if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) {
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
}
|
||||
if (enable_webgpu) {
|
||||
|
|
@ -140,26 +140,7 @@ static void RunTests(const std::vector<float>& input_data,
|
|||
int64_t interleaved = 0,
|
||||
int64_t is_packed_batching = 0,
|
||||
bool use_float16 = true) {
|
||||
// FP32 test for CPU
|
||||
RunTest(input_data,
|
||||
position_ids,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
output_data,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
head_size,
|
||||
rotary_embedding_dim,
|
||||
num_heads,
|
||||
max_sequence_length,
|
||||
interleaved,
|
||||
is_packed_batching,
|
||||
TensorType::kFloat,
|
||||
false, /* disable_cpu */
|
||||
true, /* disable_cuda */
|
||||
true /* disable_dml */);
|
||||
|
||||
// FP32 test for CUDA and DML
|
||||
// FP32 test for CPU, CUDA and DML
|
||||
RunTest(input_data,
|
||||
position_ids,
|
||||
cos_cache,
|
||||
|
|
@ -178,7 +159,7 @@ static void RunTests(const std::vector<float>& input_data,
|
|||
false, /* disable_cuda */
|
||||
false /* disable_dml */);
|
||||
|
||||
// FP16 test for CUDA and DML
|
||||
// FP16 test for CPU, CUDA and DML
|
||||
if (use_float16) {
|
||||
RunTest(input_data,
|
||||
position_ids,
|
||||
|
|
@ -194,26 +175,9 @@ static void RunTests(const std::vector<float>& input_data,
|
|||
interleaved,
|
||||
is_packed_batching,
|
||||
TensorType::kFloat16,
|
||||
true, /* disable_cpu */
|
||||
false, /* disable_cpu */
|
||||
false, /* disable_cuda*/
|
||||
false /* disable_dml */);
|
||||
|
||||
// RunTest(input_data,
|
||||
// position_ids,
|
||||
// cos_cache,
|
||||
// sin_cache,
|
||||
// output_data,
|
||||
// batch_size,
|
||||
// sequence_length,
|
||||
// head_size,
|
||||
// rotary_embedding_dim,
|
||||
// num_heads,
|
||||
// max_sequence_length,
|
||||
// interleaved,
|
||||
// TensorType::kBFloat16,
|
||||
// true, /* disable_cpu */
|
||||
// false, /* disable_cuda*/
|
||||
// false /* disable_dml */);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue