diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 80a65c6787..ab8e2d75d2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -29,6 +29,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/reorder.cpp ${MLAS_SRC_DIR}/snchwc.cpp ${MLAS_SRC_DIR}/activate.cpp + ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/logistic.cpp ${MLAS_SRC_DIR}/tanh.cpp ${MLAS_SRC_DIR}/erf.cpp @@ -324,6 +325,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 07757917de..b87cd0a77b 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -77,6 +77,16 @@ Abstract: #define MLAS_SUPPORTS_GEMM_DOUBLE #endif +#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) +// Visual Studio older than 2022 does not support fp16 intrinsic + +#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) + +#define MLAS_F16VEC_INTRINSICS_SUPPORTED + +#endif // ARM64 +#endif // Visual Studio 16 or earlier does not support fp16 intrinsic + // // Basic Linear Algebra Subprograms (BLAS) types. // @@ -553,7 +563,7 @@ private: /** * @brief Supply matrices shape and data type information to quantized gemm functions * - ** NOTE: AIsSigned == true is not supported on non-ARM devices for now. + ** NOTE: AIsSigned == true is not supported on non-ARM devices for now. ** AIsSigned == true is supported on ARM devices when BIsSigned is also true. * */ @@ -641,10 +651,10 @@ struct MLAS_SYMM_QGEMM_DATA_PARAMS { * @param [IN] Shape A single shape descriptor for all multiplicatons. Currently A and B must be signed, and accumulation mode not supported - * @param [IN] DataParams Array of data descriptors, one for each mutliplication + * @param [IN] DataParams Array of data descriptors, one for each multiplication * B must be prepacked * @param [IN] BatchN Number of multiplications - * @param [IN] ThreadPool + * @param [IN] ThreadPool */ void MLASCALL @@ -701,8 +711,8 @@ MlasGemmPackB( /** * @brief For symmetric quantized GEMM, returns size of the - * packing buffer needed for right hand side - * @param N Number of columns + * packing buffer needed for right hand side + * @param N Number of columns * @param K Number of rows * @param AIsSigned Whether left hand size is signed int8_t * @return size of the packing buffer, @@ -712,7 +722,7 @@ size_t MLASCALL MlasSymmQgemmPackBSize( size_t N, - size_t K, + size_t K, bool AIsSigned ); @@ -866,17 +876,17 @@ MlasConvSymDepthwiseGetKernelOutputCnt( /** * @brief Returns the stride M of depthwise conv kernel - * - * Most optimized path is Symmetric conv. See + * + * Most optimized path is Symmetric conv. See * MlasConvSymDepthwiseGetKernelOutputCnt(bool) - * + * * These kernels are implemented in qdwconv.cpp using * intrincic, all of them with stride val 1. We use * a slightly bigger value to improve cache reuse. * * This needs to be changed if we optimize depthwise * kernels. - * + * * @return */ inline @@ -1379,16 +1389,18 @@ using MLAS_FP16 = onnxruntime::MLFloat16; constexpr size_t FP16_SIZE = sizeof(uint16_t); - +/** + * @brief Whether current CPU supports FP16 acceleration. +*/ bool MLASCALL MlasFp16AccelerationSupported(); /** * @brief Interface for half gemm post processors. - * + * * Example implementation of this interface includes activations, * conversion from half precision to single precision, etc. - * + * * Half GEMM is computed tile by tile. When a tile of result matrix * is produced, the method Process() is called to process this tile. * Parameters of this method describe the location and shape of the @@ -1411,16 +1423,14 @@ public: }; /** - * @brief Convert half gemm result matrix to single precision float matrix + * @brief Half precision activation functions */ -class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { +class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { public: - MLAS_HALF_GEMM_2FLOAT_PROCESSOR( - float* Output, /**< address of the output matrix, row major */ - size_t RowStride /**< row stride of the output matrix */ - ) : - Output_(Output), - RowStride_(RowStride) + MLAS_HALF_GEMM_ACTIVATION_PROCESSOR( + const MLAS_ACTIVATION& Activation + ) : + Activation_(Activation) {} void @@ -1434,8 +1444,52 @@ public: ) const override; private: + const MLAS_ACTIVATION& Activation_; +}; + +inline +void +MlasFp16Activation( + const MLAS_ACTIVATION* Activation, + MLAS_FP16* Buffer, + size_t M, + size_t N, + size_t ldc + ) +{ + MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(*Activation); + proc.Process(Buffer, 0, 0, M, N, ldc); +} + + +/** + * @brief Convert half gemm result matrix to single precision float matrix +*/ +class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { +public: + MLAS_HALF_GEMM_2FLOAT_PROCESSOR( + const MLAS_ACTIVATION& Activation, + float* Output, /**< address of the output matrix, row major */ + size_t RowStride /**< row stride of the output matrix */ + ) : Activation_(Activation), + Output_(Output), + RowStride_(RowStride) + {} + + void + Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const override; + +private: + const MLAS_ACTIVATION& Activation_; float* Output_; - size_t RowStride_; + const size_t RowStride_; }; @@ -1459,17 +1513,17 @@ struct MLAS_HALF_GEMM_DATA_PARAMS { /** * @brief Half precision Batched GEMM: C = A * B + Bias * Either A or B can be fp32 or fp16 - * + * * Note: We only support uniform batching, so shapes and types of the * input must be same across all parameter blocks. - * + * * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BatchN number of batches * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] ThreadPool - * @return + * @param[in] ThreadPool + * @return */ void MLASCALL @@ -1485,7 +1539,7 @@ MlasHalfGemmBatch( /** * @brief For half precision GEMM, returns size of the * packing buffer needed for right hand side - * @param[in] N Number of columns + * @param[in] N Number of columns * @param[in] K Number of rows * @param[in] float2half Whether the input is float that * needs to be converted to half precision @@ -1503,11 +1557,11 @@ MlasHalfGemmPackBSize( /** * @brief For half precision GEMM, pack the right hand * side matrix B - * + * * @param[in] N Number of columns * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B * @param[out] PackedB Address of the packed matrix */ void @@ -1523,11 +1577,11 @@ MlasHalfGemmPackB( /** * @brief For half precision GEMM, convert the float matrix B * to half precision and pack it into a packing buffer - * + * * @param[in] N Number of columns * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B * @param[out] PackedB Address of the packed matrix */ void diff --git a/onnxruntime/core/mlas/lib/activate_fp16.cpp b/onnxruntime/core/mlas/lib/activate_fp16.cpp new file mode 100644 index 0000000000..55cbbe141e --- /dev/null +++ b/onnxruntime/core/mlas/lib/activate_fp16.cpp @@ -0,0 +1,337 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + activate_fp16.cpp + +Abstract: + + This module implements the activation routines for fp16 data types + +--*/ + +#include "fp16_common.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +// +// Templates for activation functions. +// + +template +struct MLAS_HALF_ACTIVATION_FUNCTION; + +template<> +struct MLAS_HALF_ACTIVATION_FUNCTION +{ + const MLAS_FLOAT16X8 ZeroVec = MlasZeroFloat16x8(); + + MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) + { + MLAS_UNREFERENCED_PARAMETER(Activation); + } + + MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) + { + return MlasMaximumFloat16x8(ZeroVec, Value); + } + + MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) + { + return MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(ZeroVec), Value); + } +}; + +template<> +struct MLAS_HALF_ACTIVATION_FUNCTION +{ + const MLAS_FLOAT16X8 ZeroVec = MlasZeroFloat16x8(); + + MLAS_FLOAT16X8 AlphaBroadcast; + + MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) + { + const _mlas_fp16_ alpha = MLAS_Float2Half(Activation.Parameters.LeakyRelu.alpha); + AlphaBroadcast = MlasBroadcastFloat16x8(alpha); + } + + MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) + { + MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16x8(Value, AlphaBroadcast); + return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec), + ValueTimesAlpha, Value); + } + + MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) + { + MLAS_FLOAT16X4 ValueTimesAlpha = + MlasMultiplyFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); + return MlasBitwiseSelectFloat16x4( + MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha, + Value); + } +}; + +template<> +struct MLAS_HALF_ACTIVATION_FUNCTION +{ + MLAS_FLOAT16X8 MinimumBroadcast; + MLAS_FLOAT16X8 MaximumBroadcast; + + MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) + { + const _mlas_fp16_ min = MLAS_Float2Half(Activation.Parameters.Clip.minimum); + MinimumBroadcast = MlasBroadcastFloat16x8(min); + const _mlas_fp16_ max = MLAS_Float2Half(Activation.Parameters.Clip.maximum); + MaximumBroadcast = MlasBroadcastFloat16x8(max); + } + + MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) + { + Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); + Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); + + return Value; + } + + MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) + { + Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + return Value; + } +}; + +template<> +struct MLAS_HALF_ACTIVATION_FUNCTION +{ + MLAS_FLOAT16X8 AlphaBroadcast; + MLAS_FLOAT16X8 BetaBroadcast; + MLAS_FLOAT16X8 MinimumBroadcast; + MLAS_FLOAT16X8 MaximumBroadcast; + + MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) + { + const _mlas_fp16_ alpha = MLAS_Float2Half(Activation.Parameters.HardSigmoid.alpha); + AlphaBroadcast = MlasBroadcastFloat16x8(alpha); + const _mlas_fp16_ beta = MLAS_Float2Half(Activation.Parameters.HardSigmoid.beta); + BetaBroadcast = MlasBroadcastFloat16x8(beta); + MinimumBroadcast = MlasZeroFloat16x8(); + MaximumBroadcast = MlasBroadcastFloat16x8(MLAS_Float2Half(1.0f)); + } + + MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) + { + Value = MlasMultiplyAddFloat16x8(Value, AlphaBroadcast, BetaBroadcast); + Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); + Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); + + return Value; + } + + MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) + { + Value = MlasMultiplyAddFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), + MlasToLowHalfFloat16x4(BetaBroadcast)); + Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + + return Value; + } +}; + +template +inline +void +MlasActivationKernel( + const MLAS_ACTIVATION& Activation, + MLAS_FP16* Buffer, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) +{ + MLAS_HALF_ACTIVATION_FUNCTION ActivationFunction(Activation); + + auto* CRow = reinterpret_cast<_mlas_fp16_*>(Buffer); + CRow += StartM * ldc + StartN; + + while (CountM-- > 0) { + _mlas_fp16_* buffer = CRow; + size_t n = CountN; + + while (n >= 8) { + MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); + MlasStoreFloat16x8(buffer, ActivationFunction.Activate(Vector)); + buffer += 8; + n -= 8; + } + + if (n >= 4) { + MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); + MlasStoreFloat16x4(buffer, ActivationFunction.Activate(Vector)); + buffer += 4; + n -= 4; + } + + if (n > 0) { + MLAS_FLOAT16X4 buf; + std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); + MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf); + MlasStorePartialFloat16x4(buffer, res, n); + } + + CRow += ldc; + } +} + +template<> +inline +void +MlasActivationKernel( + const MLAS_ACTIVATION& Activation, + MLAS_FP16* Buffer, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) +{ + // + // No operation. + // + + MLAS_UNREFERENCED_PARAMETER(Activation); + MLAS_UNREFERENCED_PARAMETER(Buffer); + MLAS_UNREFERENCED_PARAMETER(StartM); + MLAS_UNREFERENCED_PARAMETER(StartN); + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(CountN); + MLAS_UNREFERENCED_PARAMETER(ldc); +} + + +void +MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const +{ + switch (Activation_.ActivationKind) { + case MlasIdentityActivation: { + MlasActivationKernel(Activation_, C, StartM, StartN, CountM, CountN, ldc); + break; + } + + case MlasReluActivation: { + MlasActivationKernel(Activation_, C, StartM, StartN, CountM, CountN, + ldc); + break; + } + + case MlasLeakyReluActivation: { + MlasActivationKernel(Activation_, C, StartM, StartN, CountM, + CountN, ldc); + break; + } + + case MlasClipActivation: { + MlasActivationKernel(Activation_, C, StartM, StartN, CountM, CountN, + ldc); + ; + break; + } + + case MlasHardSigmoidActivation: { + MlasActivationKernel(Activation_, C, StartM, StartN, CountM, + CountN, ldc); + break; + } + +/* case MlasTanhActivation : { + if (N == ldc) { + MlasComputeTanh(Buffer, Buffer, M * N); + } else { + while (M-- > 0) { + MlasComputeTanh(Buffer, Buffer, N); + Buffer += ldc; + } + } + + break; + } + + case MlasLogisticActivation: { + if (N == ldc) { + MlasComputeLogistic(Buffer, Buffer, M * N); + } else { + while (M-- > 0) { + MlasComputeLogistic(Buffer, Buffer, N); + Buffer += ldc; + } + } + + break; + } +*/ + default: + // Tanh and Logistic activation not supported. + return; + } +} + +#else +// Really dumb implementation when fp16 acceleration is not supported + +#include + +MLAS_FORCEINLINE +void +CvtFloat2Half( + _mlas_fp16_* dest, + const float* src, + size_t len +) +{ + for (size_t i = 0; i < len; i++) { + *dest++ = MLAS_Float2Half(*src++); + } +} + +void +MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const +{ + std::vector buffer(CountM*CountN); + MLAS_HALF_GEMM_2FLOAT_PROCESSOR proc(this->Activation_, buffer.data(), CountN); + proc.Process(C, StartM, StartN, CountM, CountN, ldc); + + _mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C); + const auto* CRow = buffer.data(); + Output += StartM * ldc + StartN; + + while (CountM-- > 0) { + CvtFloat2Half(Output, CRow, CountN); + CRow += CountN; + Output += ldc; + } +} + +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h new file mode 100644 index 0000000000..e952a5667c --- /dev/null +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -0,0 +1,300 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + fp16_common.h + +Abstract: + + Intrinsic and inline functions for fp16 processing. + +--*/ + +#pragma once + +#include "mlas_float16.h" +#include "mlasi.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +// TODO!! Add intel fp16 implementations + +typedef float16x8_t MLAS_FLOAT16X8; +typedef float16x4_t MLAS_FLOAT16X4; +typedef uint16x8_t MLAS_UINT16X8; +typedef uint16x4_t MLAS_UINT16X4; + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasReinterpretAsFloat16x8(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasBroadcastFloat16x8(const _mlas_fp16_* Value) { return vreinterpretq_f16_u16(vld1q_dup_u16(Value)); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasBroadcastFloat16x4(const _mlas_fp16_* Value) { return vreinterpret_f16_u16(vld1_dup_u16(Value)); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasZeroFloat16x8(void) { return vreinterpretq_f16_f32(vdupq_n_f32(0.0f)); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } + +MLAS_FORCEINLINE +void +MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) +{ + vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector)); +} + +MLAS_FORCEINLINE +void +MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) +{ + vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); +} + +MLAS_FORCEINLINE +void +MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) +{ + if ((len & 2) != 0) { + vst1_lane_f32(reinterpret_cast(Buffer), vreinterpret_f32_f16(Vector), 0); + Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 1)); + Buffer += 2; + } + if ((len & 1) != 0) { + vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), 0); + } +} + +template +MLAS_FORCEINLINE void +MlasStoreLaneFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) +{ + vst1q_lane_u16(Buffer, vreinterpretq_u16_f16(Vector), Lane); +} + +MLAS_FORCEINLINE MLAS_FLOAT16X4 +MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V) +{ + // vget_low should be compiled to nothing + return vget_low_f16(V); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vaddq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +{ + return vadd_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasSubtractFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vsubq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasSubtractFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +{ + return vsub_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasMultiplyFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vmulq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasMultiplyFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +{ + return vmul_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Vector3) +{ + return vfmaq_f16(Vector3, Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasMultiplyAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Vector3) +{ + return vfma_f16(Vector3, Vector1, Vector2); +} + + +MLAS_FORCEINLINE +void +MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, _mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3) +{ + MlasMultiplyAddFloat16x8(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); +} + +MLAS_FORCEINLINE +void +MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3) +{ + MlasMultiplyAddFloat16x8(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasDivideFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vdivq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasGreaterThanFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vreinterpretq_f16_u16(vcgtq_f16(Vector1, Vector2)); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasAndFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vreinterpretq_f16_s64(vandq_s64(vreinterpretq_s64_f16(Vector1), vreinterpretq_s64_f16(Vector2))); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasOrFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vreinterpretq_f16_s64(vorrq_s64(vreinterpretq_s64_f16(Vector1), vreinterpretq_s64_f16(Vector2))); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasAndNotFloat16x8(MLAS_FLOAT16X8 VectorNot, MLAS_FLOAT16X8 Vector) +{ + return vreinterpretq_f16_s32(vandq_s32(vmvnq_s32(vreinterpretq_s32_f16(VectorNot)), vreinterpretq_s32_f16(Vector))); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasXorFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vreinterpretq_f16_s32(veorq_s32(vreinterpretq_s32_f16(Vector1), vreinterpretq_s32_f16(Vector2))); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Selection) +{ + return MlasOrFloat16x8(MlasAndFloat16x8(Vector2, Selection), + MlasAndNotFloat16x8(Selection, Vector1)); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasMaximumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vmaxq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasMaximumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +{ + return vmax_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasMinimumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vminq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasMinimumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +{ + return vmin_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange) +{ + Value = MlasMaximumFloat16x8(MlasBroadcastFloat16x8(LowerRange), Value); + Value = MlasMinimumFloat16x8(MlasBroadcastFloat16x8(UpperRange), Value); + return Value; +} + +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceAddFloat16x8(MLAS_FLOAT16X8 Vector) +{ + Vector = vpaddq_f16(Vector, Vector); + Vector = vpaddq_f16(Vector, Vector); + return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0); +} + +MLAS_FORCEINLINE +MLAS_UINT16X8 +MlasCmpLessEqualFloat16x8(MLAS_FLOAT16X8 left, MLAS_FLOAT16X8 right) +{ + return vcleq_f16(left, right); +} + +MLAS_FORCEINLINE +MLAS_UINT16X4 +MlasCmpLessEqualFloat16x4(MLAS_FLOAT16X4 left, MLAS_FLOAT16X4 right) +{ + return vcle_f16(left, right); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasBitwiseSelectFloat16x8(MLAS_UINT16X8 select, MLAS_FLOAT16X8 ones, MLAS_FLOAT16X8 zeros) +{ + return vbslq_f16(select, ones, zeros); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT16X4 zeros) +{ + return vbsl_f16(select, ones, zeros); +} + +#endif // fp16 vector intrinsic supported diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 778db2003d..49387d2fc9 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -25,7 +25,11 @@ Abstract: bool MLASCALL MlasFp16AccelerationSupported() { +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); +#else + return false; +#endif } @@ -217,9 +221,6 @@ MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process( size_t ldc ) const { - // - // TODO!! use templates to add activations in this impl - // float* Output = Output_; const auto* CRow = reinterpret_cast(C); CRow += StartM * ldc + StartN; @@ -227,7 +228,7 @@ MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process( while (CountM-- > 0) { CvtHalf2Float(Output, CRow, CountN); - + MlasActivation(&Activation_, Output, nullptr, 1, CountN, ldc); CRow += ldc; Output += RowStride_; } diff --git a/onnxruntime/test/mlas/unittest/test_fp16.h b/onnxruntime/test/mlas/unittest/test_fp16.h new file mode 100644 index 0000000000..8000b0bb71 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_fp16.h @@ -0,0 +1,67 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_fp16.h + +Abstract: + + Define fp16 type before it is available in all compilers + +--*/ + +#pragma once + +#include "test_util.h" +#include "mlas_float16.h" + + +// +// Define our own fp16 type to avoid dragging in big dependencies +// +struct MLFp16 { + uint16_t val{0}; + + MLFp16() = default; + explicit constexpr MLFp16(uint16_t x) : val(x) {} + explicit constexpr MLFp16(int32_t x) : val((uint16_t)x) {} + explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {} + + float ToFloat() const { + return MLAS_Half2Float(val); + } + + operator float() const { return ToFloat(); } + + MLFp16& operator=(float ff) { + val = MLAS_Float2Half(ff); + return *this; + } +}; + +inline bool +operator==(const MLFp16& left, const MLFp16& right) { + return left.val == right.val; +} + +inline bool +operator!=(const MLFp16& left, const MLFp16& right) { + return left.val != right.val; +} + + +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto* FillAddress = start; + size_t offset = size % 23; + + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); + } +} diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp new file mode 100644 index 0000000000..f055e0d865 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_fp16.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +class MlasFp16ActivationTest : public MlasTestBase { + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Fp16Activation"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + union AliasedValue { + unsigned u; + float f; + }; + + // N.B. The test data includes values at the edge of Tanh/Logistic boundaries. + // Identity, Relu, LeakyRelu, Tanh, Logistic, Clip, + static const AliasedValue TestData[] = { + {0x00000001}, // positive denormal + {0x80000001}, // negative denormal + {0x7fc00000}, // positive NaN + {0xffc00000}, // negative NaN + {0x00000000}, // 0.0f + {0x80000000}, // -0.0f + {0x3e800000}, // 0.25f + {0xbe800000}, // -0.25f + {0x40800000}, // 4.0f + {0xc0800000}, // -4.0f + {0x41200000}, // 10.0f + {0xc1200000}, // -10.0f + {0xc18866eb}, // -17.0502529144f + {0xc18869bb}, // -17.0516262054f + {0xc18852a8}, // -17.0403594971f + {0xc18844aa}, // -17.0335273743f + {0x418866eb}, // +17.0502529144f + {0x418869bb}, // +17.0516262054f + {0x418852a8}, // +17.0403594971f + {0x418844aa} // +17.0335273743f + }; + + constexpr size_t M = 5; + constexpr size_t N = 23; + + MatrixGuardBuffer HalfBuffer1; + auto* testData1 = HalfBuffer1.GetBuffer(M * N, true); + MatrixGuardBuffer HalfBuffer2; + auto* testData2 = HalfBuffer2.GetBuffer(M * N, true); + MatrixGuardBuffer FloatBuffer; + auto* fpBuffer = FloatBuffer.GetBuffer(M * N, true); + + MLAS_ACTIVATION_KIND acts[] = { + MlasIdentityActivation, + MlasReluActivation, + MlasLeakyReluActivation, + MlasClipActivation, + MlasHardSigmoidActivation}; + + MLAS_ACTIVATION Activation; + MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation); + MLAS_HALF_GEMM_2FLOAT_PROCESSOR converter(Activation, fpBuffer, N); + for (auto kind : acts) { + Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind); + + if (Activation.ActivationKind == MlasLeakyReluActivation) { + Activation.Parameters.LeakyRelu.alpha = 0.2f; + } else if (Activation.ActivationKind == MlasClipActivation) { + Activation.Parameters.Clip.minimum = 0.0f; + Activation.Parameters.Clip.maximum = 6.0f; + } else if (Activation.ActivationKind == MlasHardSigmoidActivation){ + Activation.Parameters.HardSigmoid.alpha = 0.2f; + Activation.Parameters.HardSigmoid.beta = 0.12f; + } + + // + // Test the vectorized activations. + // + + for (size_t i = 0; i < _countof(TestData); i++) { + testData1[i] = TestData[i].f; + testData2[i] = TestData[i].f; + } + constexpr float MinimumFillValue = -11.0f; + size_t offset = 7; + for (size_t i = _countof(TestData); i < M * N; i++) { + offset = (offset + 19) % 23; + testData1[i] = (MinimumFillValue + offset) / 16.0f; + testData2[i] = testData1[i]; + } + + proc.Process(reinterpret_cast(testData1), 0, 0, M, N, N); + converter.Process(reinterpret_cast(testData2), 0, 0, M, N, N); + + for (size_t i = 0; i < M*N; i++) { + float actual = testData1[i].ToFloat(); + if (std::isnan(actual)) { + EXPECT_TRUE(std::isnan(fpBuffer[i])) + << ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:" + << std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:" + << std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i]; + + } else { + float diff = std::abs(actual - fpBuffer[i]); + float top = std::max(std::abs(actual), std::abs(fpBuffer[i])); + float ratio = 0; + if (top > 0.0001) { + ratio = diff / top; + } + EXPECT_TRUE(ratio < 0.005) + << ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:" + << actual << ", expecting:" << fpBuffer[i]; + } + } + } + } +}; + +template<> MlasFp16ActivationTest* MlasTestFixture::mlas_tester(nullptr); + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; +}); + +#endif // fp16 vector intrinsic supported \ No newline at end of file diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 91d22341a6..ba9ce43d64 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -16,54 +16,7 @@ Abstract: #pragma once -#include "test_util.h" -#include "mlas_float16.h" - - -// -// Define our own fp16 type to avoid dragging in big dependencies -// -struct MLFp16 { - uint16_t val{0}; - - MLFp16() = default; - explicit constexpr MLFp16(uint16_t x) : val(x) {} - explicit constexpr MLFp16(int32_t x) : val((uint16_t)x) {} - explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {} - - float ToFloat() const { - return MLAS_Half2Float(val); - } - - operator float() const { return ToFloat(); } - - MLFp16& operator=(float ff) { - val = MLAS_Float2Half(ff); - return *this; - } -}; - -inline bool -operator==(const MLFp16& left, const MLFp16& right) { - return left.val == right.val; -} - -inline bool -operator!=(const MLFp16& left, const MLFp16& right) { - return left.val != right.val; -} - -template -void SmallFloatFill(T* start, size_t size) { - constexpr float MinimumFillValue = -11.0f; - auto* FillAddress = start; - size_t offset = size % 23; - - for (size_t i = 0; i < size; i++) { - offset = (offset + 21) % 23; - *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); - } -} +#include "test_fp16.h" inline bool CloseEnough(float actual, float expected){ @@ -123,6 +76,8 @@ private: MLFp16* C, size_t ldc, float* Cfloat) { + MLAS_ACTIVATION act; + act.ActivationKind = MlasIdentityActivation; std::vector Converters; Converters.reserve(BatchSize); @@ -150,7 +105,7 @@ private: } params.AIsfp32 = std::is_same::value; params.BIsfp32 = std::is_same::value; - Converters.emplace_back(Cfloat + (M * N * i), N); + Converters.emplace_back(act, Cfloat + (M * N * i), N); params.OutputProcessor = &(Converters[i]); }