Fp16 Activations (#14722)

### Description

NEON fp16 SIMD implementation of Activation functions


### Motivation and Context
Step 2 of fp16 SIMD support.

---------

Co-authored-by: Chen Fu <fuchen@microsoft.com>
This commit is contained in:
Chen Fu 2023-02-28 17:20:40 -08:00 committed by GitHub
parent 69c5edb11b
commit acc2ac627f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 930 additions and 86 deletions

View file

@ -29,6 +29,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/reorder.cpp
${MLAS_SRC_DIR}/snchwc.cpp
${MLAS_SRC_DIR}/activate.cpp
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/logistic.cpp
${MLAS_SRC_DIR}/tanh.cpp
${MLAS_SRC_DIR}/erf.cpp
@ -324,6 +325,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})

View file

@ -77,6 +77,16 @@ Abstract:
#define MLAS_SUPPORTS_GEMM_DOUBLE
#endif
#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930)
// Visual Studio versions older than 2022 do not support fp16 intrinsics
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC)
#define MLAS_F16VEC_INTRINSICS_SUPPORTED
#endif // ARM64
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic
//
// Basic Linear Algebra Subprograms (BLAS) types.
//
@ -553,7 +563,7 @@ private:
/**
* @brief Supply matrices shape and data type information to quantized gemm functions
*
** NOTE: AIsSigned == true is not supported on non-ARM devices for now.
** NOTE: AIsSigned == true is not supported on non-ARM devices for now.
** AIsSigned == true is supported on ARM devices when BIsSigned is also true.
*
*/
@ -641,10 +651,10 @@ struct MLAS_SYMM_QGEMM_DATA_PARAMS {
* @param [IN] Shape A single shape descriptor for all multiplications.
Currently A and B must be signed, and accumulation
mode not supported
* @param [IN] DataParams Array of data descriptors, one for each mutliplication
* @param [IN] DataParams Array of data descriptors, one for each multiplication
* B must be prepacked
* @param [IN] BatchN Number of multiplications
* @param [IN] ThreadPool
* @param [IN] ThreadPool
*/
void
MLASCALL
@ -701,8 +711,8 @@ MlasGemmPackB(
/**
* @brief For symmetric quantized GEMM, returns size of the
* packing buffer needed for right hand side
* @param N Number of columns
* packing buffer needed for right hand side
* @param N Number of columns
* @param K Number of rows
* @param AIsSigned Whether left hand side is signed int8_t
* @return size of the packing buffer,
@ -712,7 +722,7 @@ size_t
MLASCALL
MlasSymmQgemmPackBSize(
size_t N,
size_t K,
size_t K,
bool AIsSigned
);
@ -866,17 +876,17 @@ MlasConvSymDepthwiseGetKernelOutputCnt(
/**
* @brief Returns the stride M of depthwise conv kernel
*
* Most optimized path is Symmetric conv. See
*
* Most optimized path is Symmetric conv. See
* MlasConvSymDepthwiseGetKernelOutputCnt(bool)
*
*
* These kernels are implemented in qdwconv.cpp using
* intrinsics, all of them with stride val 1. We use
* a slightly bigger value to improve cache reuse.
*
* This needs to be changed if we optimize depthwise
* kernels.
*
*
* @return
*/
inline
@ -1379,16 +1389,18 @@ using MLAS_FP16 = onnxruntime::MLFloat16;
constexpr size_t FP16_SIZE = sizeof(uint16_t);
/**
* @brief Whether current CPU supports FP16 acceleration.
*/
bool MLASCALL
MlasFp16AccelerationSupported();
/**
* @brief Interface for half gemm post processors.
*
*
* Example implementation of this interface includes activations,
* conversion from half precision to single precision, etc.
*
*
* Half GEMM is computed tile by tile. When a tile of result matrix
* is produced, the method Process() is called to process this tile.
* Parameters of this method describe the location and shape of the
@ -1411,16 +1423,14 @@ public:
};
/**
* @brief Convert half gemm result matrix to single precision float matrix
* @brief Half precision activation functions
*/
class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
public:
MLAS_HALF_GEMM_2FLOAT_PROCESSOR(
float* Output, /**< address of the output matrix, row major */
size_t RowStride /**< row stride of the output matrix */
) :
Output_(Output),
RowStride_(RowStride)
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR(
const MLAS_ACTIVATION& Activation
) :
Activation_(Activation)
{}
void
@ -1434,8 +1444,52 @@ public:
) const override;
private:
const MLAS_ACTIVATION& Activation_;
};
inline
void
MlasFp16Activation(
const MLAS_ACTIVATION* Activation,
MLAS_FP16* Buffer,
size_t M,
size_t N,
size_t ldc
)
{
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(*Activation);
proc.Process(Buffer, 0, 0, M, N, ldc);
}
/**
* @brief Convert half gemm result matrix to single precision float matrix
*/
class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
public:
MLAS_HALF_GEMM_2FLOAT_PROCESSOR(
const MLAS_ACTIVATION& Activation,
float* Output, /**< address of the output matrix, row major */
size_t RowStride /**< row stride of the output matrix */
) : Activation_(Activation),
Output_(Output),
RowStride_(RowStride)
{}
void
Process(
MLAS_FP16* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const override;
private:
const MLAS_ACTIVATION& Activation_;
float* Output_;
size_t RowStride_;
const size_t RowStride_;
};
@ -1459,17 +1513,17 @@ struct MLAS_HALF_GEMM_DATA_PARAMS {
/**
* @brief Half precision Batched GEMM: C = A * B + Bias
* Either A or B can be fp32 or fp16
*
*
* Note: We only support uniform batching, so shapes and types of the
* input must be same across all parameter blocks.
*
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
* @param[in] ThreadPool
* @return
*/
void
MLASCALL
@ -1485,7 +1539,7 @@ MlasHalfGemmBatch(
/**
* @brief For half precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] float2half Whether the input is float that
* needs to be converted to half precision
@ -1503,11 +1557,11 @@ MlasHalfGemmPackBSize(
/**
* @brief For half precision GEMM, pack the right hand
* side matrix B
*
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void
@ -1523,11 +1577,11 @@ MlasHalfGemmPackB(
/**
* @brief For half precision GEMM, convert the float matrix B
* to half precision and pack it into a packing buffer
*
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void

View file

@ -0,0 +1,337 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
activate_fp16.cpp
Abstract:
This module implements the activation routines for fp16 data types
--*/
#include "fp16_common.h"
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
//
// Templates for activation functions.
//
// The primary template is only declared here, never defined: an activation
// kind can be used with MlasActivationKernel only if an explicit
// specialization exists for it. Each specialization in this file is
// constructed from an MLAS_ACTIVATION and exposes Activate() overloads for
// MLAS_FLOAT16X8 and MLAS_FLOAT16X4 vectors.
//
template<MLAS_ACTIVATION_KIND ActivationKind>
struct MLAS_HALF_ACTIVATION_FUNCTION;
//
// Relu: every lane becomes the maximum of zero and the input lane.
//
template<>
struct MLAS_HALF_ACTIVATION_FUNCTION<MlasReluActivation>
{
    // All-zero vector used as the lower bound.
    const MLAS_FLOAT16X8 ZeroVec = MlasZeroFloat16x8();

    MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation)
    {
        // Relu reads nothing from the activation descriptor.
        MLAS_UNREFERENCED_PARAMETER(Activation);
    }

    // Eight lanes at a time.
    MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
    {
        const MLAS_FLOAT16X8 Result = MlasMaximumFloat16x8(ZeroVec, Value);
        return Result;
    }

    // Four lanes at a time; the bound is the low half of ZeroVec.
    MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
    {
        const MLAS_FLOAT16X4 ZeroLow = MlasToLowHalfFloat16x4(ZeroVec);
        return MlasMaximumFloat16x4(ZeroLow, Value);
    }
};
template<>
struct MLAS_HALF_ACTIVATION_FUNCTION<MlasLeakyReluActivation>
{
const MLAS_FLOAT16X8 ZeroVec = MlasZeroFloat16x8();
MLAS_FLOAT16X8 AlphaBroadcast;
MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation)
{
const _mlas_fp16_ alpha = MLAS_Float2Half(Activation.Parameters.LeakyRelu.alpha);
AlphaBroadcast = MlasBroadcastFloat16x8(alpha);
}
MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16x8(Value, AlphaBroadcast);
return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec),
ValueTimesAlpha, Value);
}
MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
MLAS_FLOAT16X4 ValueTimesAlpha =
MlasMultiplyFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast));
return MlasBitwiseSelectFloat16x4(
MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha,
Value);
}
};
template<>
struct MLAS_HALF_ACTIVATION_FUNCTION<MlasClipActivation>
{
MLAS_FLOAT16X8 MinimumBroadcast;
MLAS_FLOAT16X8 MaximumBroadcast;
MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation)
{
const _mlas_fp16_ min = MLAS_Float2Half(Activation.Parameters.Clip.minimum);
MinimumBroadcast = MlasBroadcastFloat16x8(min);
const _mlas_fp16_ max = MLAS_Float2Half(Activation.Parameters.Clip.maximum);
MaximumBroadcast = MlasBroadcastFloat16x8(max);
}
MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
Value = MlasMaximumFloat16x8(MinimumBroadcast, Value);
Value = MlasMinimumFloat16x8(MaximumBroadcast, Value);
return Value;
}
MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
return Value;
}
};
template<>
struct MLAS_HALF_ACTIVATION_FUNCTION<MlasHardSigmoidActivation>
{
MLAS_FLOAT16X8 AlphaBroadcast;
MLAS_FLOAT16X8 BetaBroadcast;
MLAS_FLOAT16X8 MinimumBroadcast;
MLAS_FLOAT16X8 MaximumBroadcast;
MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation)
{
const _mlas_fp16_ alpha = MLAS_Float2Half(Activation.Parameters.HardSigmoid.alpha);
AlphaBroadcast = MlasBroadcastFloat16x8(alpha);
const _mlas_fp16_ beta = MLAS_Float2Half(Activation.Parameters.HardSigmoid.beta);
BetaBroadcast = MlasBroadcastFloat16x8(beta);
MinimumBroadcast = MlasZeroFloat16x8();
MaximumBroadcast = MlasBroadcastFloat16x8(MLAS_Float2Half(1.0f));
}
MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
Value = MlasMultiplyAddFloat16x8(Value, AlphaBroadcast, BetaBroadcast);
Value = MlasMinimumFloat16x8(MaximumBroadcast, Value);
Value = MlasMaximumFloat16x8(MinimumBroadcast, Value);
return Value;
}
MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
Value = MlasMultiplyAddFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast),
MlasToLowHalfFloat16x4(BetaBroadcast));
Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
return Value;
}
};
template<MLAS_ACTIVATION_KIND ActivationKind>
inline
void
MlasActivationKernel(
const MLAS_ACTIVATION& Activation,
MLAS_FP16* Buffer,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
)
{
MLAS_HALF_ACTIVATION_FUNCTION<ActivationKind> ActivationFunction(Activation);
auto* CRow = reinterpret_cast<_mlas_fp16_*>(Buffer);
CRow += StartM * ldc + StartN;
while (CountM-- > 0) {
_mlas_fp16_* buffer = CRow;
size_t n = CountN;
while (n >= 8) {
MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
MlasStoreFloat16x8(buffer, ActivationFunction.Activate(Vector));
buffer += 8;
n -= 8;
}
if (n >= 4) {
MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
MlasStoreFloat16x4(buffer, ActivationFunction.Activate(Vector));
buffer += 4;
n -= 4;
}
if (n > 0) {
MLAS_FLOAT16X4 buf;
std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf);
MlasStorePartialFloat16x4(buffer, res, n);
}
CRow += ldc;
}
}
//
// Identity activation: the tile is left exactly as it is, so none of the
// arguments are used.
//
template<>
inline
void
MlasActivationKernel<MlasIdentityActivation>(
    const MLAS_ACTIVATION& /* Activation */,
    MLAS_FP16* /* Buffer */,
    size_t /* StartM */,
    size_t /* StartN */,
    size_t /* CountM */,
    size_t /* CountN */,
    size_t /* ldc */
    )
{
    // Intentionally empty.
}
//
// Dispatches to the kernel instantiation matching the configured activation
// kind and applies it to the tile [StartM, StartM + CountM) x
// [StartN, StartN + CountN) of C, whose row stride is ldc.
//
// Identity, Relu, LeakyRelu, Clip and HardSigmoid are handled. Any other
// kind (for example Tanh or Logistic, which have no fp16 kernel here) falls
// through to the default case and C is left unmodified.
//
void
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
    MLAS_FP16* C,
    size_t StartM,
    size_t StartN,
    size_t CountM,
    size_t CountN,
    size_t ldc
    ) const
{
    const MLAS_ACTIVATION& Act = Activation_;

    switch (Act.ActivationKind) {
        case MlasIdentityActivation:
            MlasActivationKernel<MlasIdentityActivation>(Act, C, StartM, StartN, CountM, CountN, ldc);
            return;

        case MlasReluActivation:
            MlasActivationKernel<MlasReluActivation>(Act, C, StartM, StartN, CountM, CountN, ldc);
            return;

        case MlasLeakyReluActivation:
            MlasActivationKernel<MlasLeakyReluActivation>(Act, C, StartM, StartN, CountM, CountN, ldc);
            return;

        case MlasClipActivation:
            MlasActivationKernel<MlasClipActivation>(Act, C, StartM, StartN, CountM, CountN, ldc);
            return;

        case MlasHardSigmoidActivation:
            MlasActivationKernel<MlasHardSigmoidActivation>(Act, C, StartM, StartN, CountM, CountN, ldc);
            return;

        default:
            // Tanh and Logistic activation not supported.
            return;
    }
}
#else
// Really dumb implementation when fp16 acceleration is not supported
#include <vector>
//
// Converts len single precision values starting at src to half precision,
// one element at a time, and writes them to dest.
//
MLAS_FORCEINLINE
void
CvtFloat2Half(
    _mlas_fp16_* dest,
    const float* src,
    size_t len
    )
{
    size_t remaining = len;
    while (remaining > 0) {
        *dest = MLAS_Float2Half(*src);
        ++dest;
        ++src;
        --remaining;
    }
}
//
// Fallback used when fp16 vector intrinsics are unavailable.
//
// The tile is handed to MLAS_HALF_GEMM_2FLOAT_PROCESSOR, which fills a
// temporary dense CountM x CountN fp32 buffer; that buffer is then converted
// back to half precision row by row into the same tile of C.
//
void
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
    MLAS_FP16* C,
    size_t StartM,
    size_t StartN,
    size_t CountM,
    size_t CountN,
    size_t ldc
    ) const
{
    // Scratch fp32 tile, packed with row stride CountN.
    std::vector<float> FloatTile(CountM*CountN);

    MLAS_HALF_GEMM_2FLOAT_PROCESSOR Converter(this->Activation_, FloatTile.data(), CountN);
    Converter.Process(C, StartM, StartN, CountM, CountN, ldc);

    _mlas_fp16_* const HalfTile = reinterpret_cast<_mlas_fp16_*>(C) + (StartM * ldc + StartN);
    const float* const FloatBase = FloatTile.data();

    for (size_t Row = 0; Row < CountM; Row++) {
        CvtFloat2Half(HalfTile + Row * ldc, FloatBase + Row * CountN, CountN);
    }
}
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED

View file

@ -0,0 +1,300 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
fp16_common.h
Abstract:
Intrinsic and inline functions for fp16 processing.
--*/
#pragma once
#include "mlas_float16.h"
#include "mlasi.h"
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
// TODO!! Add intel fp16 implementations
// MLAS names for the NEON 8-lane / 4-lane half precision vectors and the
// unsigned 16-bit vectors of the same lane counts.
using MLAS_FLOAT16X8 = float16x8_t;
using MLAS_FLOAT16X4 = float16x4_t;
using MLAS_UINT16X8 = uint16x8_t;
using MLAS_UINT16X4 = uint16x4_t;
// Reinterprets the bits of an int32 vector as 8 half precision lanes; no
// value conversion takes place.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasReinterpretAsFloat16x8(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); }
// Replicates the raw 16-bit pattern Value into all 8 lanes.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); }
// Replicates the raw 16-bit pattern Value into all 4 lanes.
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); }
// Loads one 16-bit element from memory and replicates it into all 8 lanes.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBroadcastFloat16x8(const _mlas_fp16_* Value) { return vreinterpretq_f16_u16(vld1q_dup_u16(Value)); }
// Loads one 16-bit element from memory and replicates it into all 4 lanes.
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasBroadcastFloat16x4(const _mlas_fp16_* Value) { return vreinterpret_f16_u16(vld1_dup_u16(Value)); }
// Returns a vector with every bit clear, built from an fp32 zero vector and
// reinterpreted as 8 half precision lanes.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasZeroFloat16x8(void) { return vreinterpretq_f16_f32(vdupq_n_f32(0.0f)); }
// Loads 8 consecutive 16-bit elements from Buffer.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); }
// Loads 4 consecutive 16-bit elements from Buffer.
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); }
// Stores all 8 lanes of Vector to 8 consecutive 16-bit elements at Buffer.
MLAS_FORCEINLINE
void
MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector)
{
vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector));
}
// Stores all 4 lanes of Vector to 4 consecutive 16-bit elements at Buffer.
MLAS_FORCEINLINE
void
MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
{
vst1_u16(Buffer, vreinterpret_u16_f16(Vector));
}
// Stores the first (len & 3) lanes of Vector to Buffer.
//
// Only the two low bits of len are examined, so the function writes 0 to 3
// elements; callers in this file pass len in the range 1..3.
MLAS_FORCEINLINE
void
MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len)
{
if ((len & 2) != 0) {
// Write lanes 0 and 1 together as one 32-bit lane.
// NOTE(review): Buffer is accessed through a float* here; confirm the
// alignment/aliasing assumptions hold for all supported toolchains.
vst1_lane_f32(reinterpret_cast<float*>(Buffer), vreinterpret_f32_f16(Vector), 0);
// Move lanes 2 and 3 down into lanes 0 and 1 for the single-lane store below.
Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 1));
Buffer += 2;
}
if ((len & 1) != 0) {
// Write the one remaining element (lane 0 after any shift above).
vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), 0);
}
}
// Stores the single lane selected by the compile-time index Lane to *Buffer.
template <unsigned Lane>
MLAS_FORCEINLINE void
MlasStoreLaneFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector)
{
vst1q_lane_u16(Buffer, vreinterpretq_u16_f16(Vector), Lane);
}
// Returns the low 4 lanes of an 8-lane vector.
MLAS_FORCEINLINE MLAS_FLOAT16X4
MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V)
{
// vget_low should be compiled to nothing
return vget_low_f16(V);
}
// Lane-wise Vector1 + Vector2 (8 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vaddq_f16(Vector1, Vector2);
}
// Lane-wise Vector1 + Vector2 (4 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
{
return vadd_f16(Vector1, Vector2);
}
// Lane-wise Vector1 - Vector2 (8 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasSubtractFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vsubq_f16(Vector1, Vector2);
}
// Lane-wise Vector1 - Vector2 (4 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasSubtractFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
{
return vsub_f16(Vector1, Vector2);
}
// Lane-wise Vector1 * Vector2 (8 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasMultiplyFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vmulq_f16(Vector1, Vector2);
}
// Lane-wise Vector1 * Vector2 (4 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasMultiplyFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
{
return vmul_f16(Vector1, Vector2);
}
// Lane-wise fused multiply-add: Vector1 * Vector2 + Vector3 (8 lanes).
// Note the accumulator is the FIRST argument of the underlying intrinsic.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Vector3)
{
return vfmaq_f16(Vector3, Vector1, Vector2);
}
// Lane-wise fused multiply-add: Vector1 * Vector2 + Vector3 (4 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasMultiplyAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Vector3)
{
return vfma_f16(Vector3, Vector1, Vector2);
}
// Lane-wise fused multiply-add with a scalar multiplier:
// Vector1 * broadcast(Scalar2) + Vector3.
//
// Scalar2 is a raw half precision bit pattern that is replicated into every
// lane before the multiply. The result is returned; previously this overload
// was declared void and dropped the computed value, making it a no-op.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, _mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3)
{
    return MlasMultiplyAddFloat16x8(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3);
}
// Lane-wise fused multiply-add with a scalar addend:
// Vector1 * Vector2 + broadcast(Scalar3).
//
// Scalar3 is a raw half precision bit pattern that is replicated into every
// lane before the add. The result is returned; previously this overload was
// declared void and dropped the computed value, making it a no-op.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3)
{
    return MlasMultiplyAddFloat16x8(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3));
}
// Lane-wise Vector1 / Vector2 (8 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasDivideFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vdivq_f16(Vector1, Vector2);
}
// Lane-wise Vector1 > Vector2. The comparison mask (all bits set in lanes
// where the comparison holds, all clear otherwise) is returned reinterpreted
// as a half precision vector, so it is meant for the bitwise helpers below
// rather than for arithmetic.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasGreaterThanFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vreinterpretq_f16_u16(vcgtq_f16(Vector1, Vector2));
}
// Bitwise Vector1 & Vector2 over the full 128 bits.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasAndFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vreinterpretq_f16_s64(vandq_s64(vreinterpretq_s64_f16(Vector1), vreinterpretq_s64_f16(Vector2)));
}
// Bitwise Vector1 | Vector2 over the full 128 bits.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasOrFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vreinterpretq_f16_s64(vorrq_s64(vreinterpretq_s64_f16(Vector1), vreinterpretq_s64_f16(Vector2)));
}
// Bitwise (~VectorNot) & Vector over the full 128 bits.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasAndNotFloat16x8(MLAS_FLOAT16X8 VectorNot, MLAS_FLOAT16X8 Vector)
{
return vreinterpretq_f16_s32(vandq_s32(vmvnq_s32(vreinterpretq_s32_f16(VectorNot)), vreinterpretq_s32_f16(Vector)));
}
// Bitwise Vector1 ^ Vector2 over the full 128 bits.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasXorFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vreinterpretq_f16_s32(veorq_s32(vreinterpretq_s32_f16(Vector1), vreinterpretq_s32_f16(Vector2)));
}
// Bitwise blend: takes bits from Vector2 where the corresponding Selection
// bit is set and from Vector1 where it is clear.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Selection)
{
return MlasOrFloat16x8(MlasAndFloat16x8(Vector2, Selection),
MlasAndNotFloat16x8(Selection, Vector1));
}
// Lane-wise maximum of Vector1 and Vector2 (8 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasMaximumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vmaxq_f16(Vector1, Vector2);
}
// Lane-wise maximum of Vector1 and Vector2 (4 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasMaximumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
{
return vmax_f16(Vector1, Vector2);
}
// Lane-wise minimum of Vector1 and Vector2 (8 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasMinimumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vminq_f16(Vector1, Vector2);
}
// Lane-wise minimum of Vector1 and Vector2 (4 lanes).
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasMinimumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
{
return vmin_f16(Vector1, Vector2);
}
// Limits every lane of Value to [LowerRange, UpperRange]. Both bounds are raw
// half precision bit patterns that are broadcast before use; the lower bound
// is applied first, then the upper bound.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange)
{
Value = MlasMaximumFloat16x8(MlasBroadcastFloat16x8(LowerRange), Value);
Value = MlasMinimumFloat16x8(MlasBroadcastFloat16x8(UpperRange), Value);
return Value;
}
// Horizontal sum of all 8 lanes, returned as a raw half precision bit
// pattern.
//
// Each pairwise add halves the number of distinct partial sums:
//   after pass 1: lane i holds v[2i] + v[2i+1]        (4 partial sums)
//   after pass 2: lane 0 holds v0+v1+v2+v3, lane 1 holds v4+v5+v6+v7
//   after pass 3: lane 0 holds the sum of all 8 lanes
// Three passes are therefore required for an 8-lane vector; with only two
// passes lane 0 would contain just the sum of the first four lanes.
MLAS_FORCEINLINE
_mlas_fp16_
MlasReduceAddFloat16x8(MLAS_FLOAT16X8 Vector)
{
    Vector = vpaddq_f16(Vector, Vector);
    Vector = vpaddq_f16(Vector, Vector);
    Vector = vpaddq_f16(Vector, Vector);
    return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0);
}
// Lane-wise left <= right (8 lanes). Each result lane has all bits set when
// the comparison holds and all bits clear otherwise.
MLAS_FORCEINLINE
MLAS_UINT16X8
MlasCmpLessEqualFloat16x8(MLAS_FLOAT16X8 left, MLAS_FLOAT16X8 right)
{
return vcleq_f16(left, right);
}
// Lane-wise left <= right (4 lanes); same mask convention as the 8-lane form.
MLAS_FORCEINLINE
MLAS_UINT16X4
MlasCmpLessEqualFloat16x4(MLAS_FLOAT16X4 left, MLAS_FLOAT16X4 right)
{
return vcle_f16(left, right);
}
// Bitwise select (8 lanes): for every bit, takes the bit from `ones` where
// the `select` bit is 1 and from `zeros` where it is 0.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBitwiseSelectFloat16x8(MLAS_UINT16X8 select, MLAS_FLOAT16X8 ones, MLAS_FLOAT16X8 zeros)
{
return vbslq_f16(select, ones, zeros);
}
// Bitwise select (4 lanes); same convention as the 8-lane form.
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT16X4 zeros)
{
return vbsl_f16(select, ones, zeros);
}
#endif // fp16 vector intrinsic supported

View file

@ -25,7 +25,11 @@ Abstract:
//
// Reports whether fp16 vector acceleration can be used. When the build has
// no fp16 vector intrinsics the answer is always false; otherwise it is
// whatever the CPUID info object reports.
//
bool MLASCALL
MlasFp16AccelerationSupported()
{
    bool Supported = false;
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
    Supported = MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration();
#endif
    return Supported;
}
@ -217,9 +221,6 @@ MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process(
size_t ldc
) const
{
//
// TODO!! use templates to add activations in this impl
//
float* Output = Output_;
const auto* CRow = reinterpret_cast<const _mlas_fp16_*>(C);
CRow += StartM * ldc + StartN;
@ -227,7 +228,7 @@ MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process(
while (CountM-- > 0) {
CvtHalf2Float(Output, CRow, CountN);
MlasActivation(&Activation_, Output, nullptr, 1, CountN, ldc);
CRow += ldc;
Output += RowStride_;
}

View file

@ -0,0 +1,67 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
test_fp16.h
Abstract:
Define fp16 type before it is available in all compilers
--*/
#pragma once
#include "test_util.h"
#include "mlas_float16.h"
//
// Define our own fp16 type to avoid dragging in big dependencies
//
// Minimal half precision value type for the tests: a raw 16-bit pattern plus
// conversions to and from float through MLAS_Half2Float / MLAS_Float2Half.
struct MLFp16 {
    // Raw half precision bit pattern; zero by default.
    uint16_t val{0};

    MLFp16() = default;

    // Construct from a raw bit pattern.
    explicit constexpr MLFp16(uint16_t x) : val(x) {}

    // Construct from an int holding a raw bit pattern (truncated to 16 bits).
    explicit constexpr MLFp16(int32_t x) : val(static_cast<uint16_t>(x)) {}

    // Construct by converting a float value to half precision.
    explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {}

    // Convert the stored half precision value to float.
    float ToFloat() const {
        return MLAS_Half2Float(val);
    }

    // Implicit conversion to float; same value as ToFloat().
    operator float() const { return ToFloat(); }

    // Assign from a float by converting it to half precision.
    MLFp16& operator=(float ff) {
        val = MLAS_Float2Half(ff);
        return *this;
    }
};
// Two MLFp16 values are equal when their raw 16-bit patterns are identical.
inline bool
operator==(const MLFp16& left, const MLFp16& right) {
    const bool same_bits = (left.val == right.val);
    return same_bits;
}
// Two MLFp16 values differ when their raw 16-bit patterns differ.
inline bool
operator!=(const MLFp16& left, const MLFp16& right) {
    const bool different_bits = (left.val != right.val);
    return different_bits;
}
// Fills `size` elements starting at `start` with a deterministic sequence of
// small values. A counter seeded with (size % 23) is advanced by 21 modulo 23
// before every element, and the element is set to T((-11 + counter) / 16).
template<typename T>
void SmallFloatFill(T* start, size_t size) {
    constexpr float MinimumFillValue = -11.0f;
    constexpr size_t Modulus = 23;
    constexpr size_t Step = 21;

    size_t counter = size % Modulus;
    for (size_t index = 0; index != size; ++index) {
        counter = (counter + Step) % Modulus;
        start[index] = T((MinimumFillValue + counter) / 16.0f);
    }
}

View file

@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test_fp16.h"
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
//
// Compares the half precision activation processor against the fp32
// activation path: the same half precision input is run through
// MLAS_HALF_GEMM_ACTIVATION_PROCESSOR (in place) and through
// MLAS_HALF_GEMM_2FLOAT_PROCESSOR (into an fp32 buffer), and the two results
// must agree within a relative tolerance.
//
class MlasFp16ActivationTest : public MlasTestBase {
public:
static const char* GetTestSuiteName() {
static const std::string suite_name("Fp16Activation");
return suite_name.c_str();
}
void ExecuteShort(void) override {
// Lets the edge-case inputs below be written as raw fp32 bit patterns.
union AliasedValue {
unsigned u;
float f;
};
// N.B. The test data includes values at the edge of Tanh/Logistic boundaries.
// Identity, Relu, LeakyRelu, Tanh, Logistic, Clip,
// NOTE: the `acts` list below does not include Tanh or Logistic, so those
// two kinds are not exercised by this test.
static const AliasedValue TestData[] = {
{0x00000001}, // positive denormal
{0x80000001}, // negative denormal
{0x7fc00000}, // positive NaN
{0xffc00000}, // negative NaN
{0x00000000}, // 0.0f
{0x80000000}, // -0.0f
{0x3e800000}, // 0.25f
{0xbe800000}, // -0.25f
{0x40800000}, // 4.0f
{0xc0800000}, // -4.0f
{0x41200000}, // 10.0f
{0xc1200000}, // -10.0f
{0xc18866eb}, // -17.0502529144f
{0xc18869bb}, // -17.0516262054f
{0xc18852a8}, // -17.0403594971f
{0xc18844aa}, // -17.0335273743f
{0x418866eb}, // +17.0502529144f
{0x418869bb}, // +17.0516262054f
{0x418852a8}, // +17.0403594971f
{0x418844aa} // +17.0335273743f
};
// 23 columns = 2 full 8-lane groups + one 4-lane group + a 3-lane tail,
// so every code path of the vectorized row loop is covered.
constexpr size_t M = 5;
constexpr size_t N = 23;
// testData1 is activated in place in fp16; testData2 is the identical input
// for the fp32 reference path, whose output lands in fpBuffer.
MatrixGuardBuffer<MLFp16> HalfBuffer1;
auto* testData1 = HalfBuffer1.GetBuffer(M * N, true);
MatrixGuardBuffer<MLFp16> HalfBuffer2;
auto* testData2 = HalfBuffer2.GetBuffer(M * N, true);
MatrixGuardBuffer<float> FloatBuffer;
auto* fpBuffer = FloatBuffer.GetBuffer(M * N, true);
MLAS_ACTIVATION_KIND acts[] = {
MlasIdentityActivation,
MlasReluActivation,
MlasLeakyReluActivation,
MlasClipActivation,
MlasHardSigmoidActivation};
// Both processors keep a reference to this object, so updating it inside
// the loop reconfigures them without rebuilding them.
MLAS_ACTIVATION Activation;
MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation);
MLAS_HALF_GEMM_2FLOAT_PROCESSOR converter(Activation, fpBuffer, N);
for (auto kind : acts) {
Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind);
if (Activation.ActivationKind == MlasLeakyReluActivation) {
Activation.Parameters.LeakyRelu.alpha = 0.2f;
} else if (Activation.ActivationKind == MlasClipActivation) {
Activation.Parameters.Clip.minimum = 0.0f;
Activation.Parameters.Clip.maximum = 6.0f;
} else if (Activation.ActivationKind == MlasHardSigmoidActivation){
Activation.Parameters.HardSigmoid.alpha = 0.2f;
Activation.Parameters.HardSigmoid.beta = 0.12f;
}
//
// Test the vectorized activations.
//
// The inputs are rewritten for every kind because proc.Process overwrites
// testData1 in place.
for (size_t i = 0; i < _countof(TestData); i++) {
testData1[i] = TestData[i].f;
testData2[i] = TestData[i].f;
}
// Fill the rest with a deterministic pattern of small values; testData2
// is copied from testData1 so both hold identical half precision bits.
constexpr float MinimumFillValue = -11.0f;
size_t offset = 7;
for (size_t i = _countof(TestData); i < M * N; i++) {
offset = (offset + 19) % 23;
testData1[i] = (MinimumFillValue + offset) / 16.0f;
testData2[i] = testData1[i];
}
proc.Process(reinterpret_cast<MLAS_FP16*>(testData1), 0, 0, M, N, N);
converter.Process(reinterpret_cast<MLAS_FP16*>(testData2), 0, 0, M, N, N);
for (size_t i = 0; i < M*N; i++) {
float actual = testData1[i].ToFloat();
if (std::isnan(actual)) {
// A NaN result is accepted only if the reference is NaN as well.
EXPECT_TRUE(std::isnan(fpBuffer[i]))
<< ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:"
<< std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:"
<< std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
} else {
// Relative error against the larger magnitude; when both values are at
// most 0.0001 in magnitude the ratio stays 0 and the check passes.
float diff = std::abs(actual - fpBuffer[i]);
float top = std::max(std::abs(actual), std::abs(fpBuffer[i]));
float ratio = 0;
if (top > 0.0001) {
ratio = diff / top;
}
EXPECT_TRUE(ratio < 0.005)
<< ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:"
<< actual << ", expecting:" << fpBuffer[i];
}
}
}
}
};
// Definition of the fixture's static tester pointer for this test class.
template<> MlasFp16ActivationTest* MlasTestFixture<MlasFp16ActivationTest>::mlas_tester(nullptr);
// Registers the test at static-initialization time. Only the short-execute
// variant is registered; the callback returns 0 when is_short_execute is false.
static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
return is_short_execute ? MlasDirectShortExecuteTests<MlasFp16ActivationTest>::RegisterShortExecute() : 0;
});
#endif // fp16 vector intrinsic supported

View file

@ -16,54 +16,7 @@ Abstract:
#pragma once
#include "test_util.h"
#include "mlas_float16.h"
//
// Define our own fp16 type to avoid dragging in big dependencies
//
// MLFp16 stores a raw 16-bit half precision pattern and converts to and from
// float through MLAS_Half2Float / MLAS_Float2Half.
struct MLFp16 {
// Raw half precision bit pattern; zero by default.
uint16_t val{0};
MLFp16() = default;
// The two integer constructors take a raw bit pattern, not a numeric value.
explicit constexpr MLFp16(uint16_t x) : val(x) {}
explicit constexpr MLFp16(int32_t x) : val((uint16_t)x) {}
// Converts a float value to half precision.
explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {}
float ToFloat() const {
return MLAS_Half2Float(val);
}
operator float() const { return ToFloat(); }
MLFp16& operator=(float ff) {
val = MLAS_Float2Half(ff);
return *this;
}
};
// Equality compares the raw bit patterns.
inline bool
operator==(const MLFp16& left, const MLFp16& right) {
return left.val == right.val;
}
// Inequality compares the raw bit patterns.
inline bool
operator!=(const MLFp16& left, const MLFp16& right) {
return left.val != right.val;
}
// Fills `size` elements with a deterministic sequence of small values: a
// counter seeded with (size % 23) is advanced by 21 modulo 23 before each
// element, which is then set to T((-11 + counter) / 16).
template<typename T>
void SmallFloatFill(T* start, size_t size) {
constexpr float MinimumFillValue = -11.0f;
auto* FillAddress = start;
size_t offset = size % 23;
for (size_t i = 0; i < size; i++) {
offset = (offset + 21) % 23;
*FillAddress++ = T((MinimumFillValue + offset) / 16.0f);
}
}
#include "test_fp16.h"
inline bool
CloseEnough(float actual, float expected){
@ -123,6 +76,8 @@ private:
MLFp16* C,
size_t ldc,
float* Cfloat) {
MLAS_ACTIVATION act;
act.ActivationKind = MlasIdentityActivation;
std::vector<MLAS_HALF_GEMM_2FLOAT_PROCESSOR> Converters;
Converters.reserve(BatchSize);
@ -150,7 +105,7 @@ private:
}
params.AIsfp32 = std::is_same<AType, float>::value;
params.BIsfp32 = std::is_same<BType, float>::value;
Converters.emplace_back(Cfloat + (M * N * i), N);
Converters.emplace_back(act, Cfloat + (M * N * i), N);
params.OutputProcessor = &(Converters[i]);
}