diff --git a/onnxruntime/core/mlas/lib/activate_fp16.cpp b/onnxruntime/core/mlas/lib/activate_fp16.cpp index 0b1397db77..564c6a4fb6 100644 --- a/onnxruntime/core/mlas/lib/activate_fp16.cpp +++ b/onnxruntime/core/mlas/lib/activate_fp16.cpp @@ -51,12 +51,12 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - return MlasMaximum(ZeroVec, Value); + return MlasMaximumFloat16(ZeroVec, Value); } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - return MlasMaximum(MlasToLowHalfFloat16x4(ZeroVec), Value); + return MlasMaximumFloat16(MlasToLowHalfFloat16x4(ZeroVec), Value); } }; @@ -75,7 +75,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiply(Value, AlphaBroadcast); + MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16(Value, AlphaBroadcast); return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec), ValueTimesAlpha, Value); } @@ -83,7 +83,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { MLAS_FLOAT16X4 ValueTimesAlpha = - MlasMultiply(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); + MlasMultiplyFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); return MlasBitwiseSelectFloat16x4( MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha, Value); @@ -539,16 +539,16 @@ struct MLAS_HALF_ACTIVATION_FUNCTION { MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - Value = MlasMaximum(MinimumBroadcast, Value); - Value = MlasMinimum(MaximumBroadcast, Value); + Value = MlasMaximumFloat16(MinimumBroadcast, Value); + Value = MlasMinimumFloat16(MaximumBroadcast, Value); return Value; } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - Value = MlasMaximum(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); - Value = MlasMinimum(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = 
MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); return Value; } }; @@ -573,19 +573,19 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - Value = MlasMultiplyAdd(Value, AlphaBroadcast, BetaBroadcast); - Value = MlasMinimum(MaximumBroadcast, Value); - Value = MlasMaximum(MinimumBroadcast, Value); + Value = MlasMultiplyAddFloat16(Value, AlphaBroadcast, BetaBroadcast); + Value = MlasMinimumFloat16(MaximumBroadcast, Value); + Value = MlasMaximumFloat16(MinimumBroadcast, Value); return Value; } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - Value = MlasMultiplyAdd(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), + Value = MlasMultiplyAddFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), MlasToLowHalfFloat16x4(BetaBroadcast)); - Value = MlasMinimum(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); - Value = MlasMaximum(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); return Value; } @@ -692,7 +692,7 @@ MlasActivationKernel( MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc); MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); addsrc += 8; - Vector = MlasAdd(Vector, AVec); + Vector = MlasAddFloat16(Vector, AVec); Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x8(buffer, Vector); buffer += 8; @@ -703,7 +703,7 @@ MlasActivationKernel( MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc); MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); addsrc += 4; - Vector = MlasAdd(Vector, AVec); + Vector = MlasAddFloat16(Vector, AVec); Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x4(buffer, Vector); buffer += 4; @@ -715,7 +715,7 @@ MlasActivationKernel( MLAS_FLOAT16X4 buf; std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_)); std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); - buf = MlasAdd(buf, addbuf); + buf = MlasAddFloat16(buf, addbuf); 
buf = ActivationFunction.Activate(buf); MlasStorePartialFloat16x4(buffer, buf, n); } diff --git a/onnxruntime/core/mlas/lib/dwconv.cpp b/onnxruntime/core/mlas/lib/dwconv.cpp index febc2b15b2..0fff937010 100644 --- a/onnxruntime/core/mlas/lib/dwconv.cpp +++ b/onnxruntime/core/mlas/lib/dwconv.cpp @@ -43,7 +43,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X8 InputVector = MlasLoadFloat16x8(&Input[k][ChannelOffset]); MLAS_FLOAT16X8 FilterVector = MlasLoadFloat16x8(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAdd(InputVector, FilterVector, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator); ChannelKernelOffset += Channels; } MlasStoreFloat16x8(Output, Accumulator); @@ -61,7 +61,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X4 InputVector = MlasLoadFloat16x4(&Input[k][ChannelOffset]); MLAS_FLOAT16X4 FilterVector = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAdd(InputVector, FilterVector, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator); ChannelKernelOffset += Channels; } MlasStoreFloat16x4(Output, Accumulator); @@ -80,7 +80,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X4 InputValue = MlasLoadFloat16x4(&Input[k][ChannelOffset]); MLAS_FLOAT16X4 FilterValue = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAdd(InputValue, FilterValue, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputValue, FilterValue, Accumulator); ChannelKernelOffset += Channels; } MlasStorePartialFloat16x4(Output, Accumulator, c); diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 60ec0ccdc4..d4713cce5a 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -32,23 +32,23 @@ typedef int16x4_t MLAS_INT16X4; MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasReinterpretAsFloat16(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } 
+MlasReinterpretInt32AsFloat16(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasReinterpretAsFloat16(MLAS_INT16X8 Vector) { return vreinterpretq_f16_s16(Vector); } +MlasReinterpretInt16AsFloat16(MLAS_INT16X8 Vector) { return vreinterpretq_f16_s16(Vector); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasReinterpretAsFloat16(MLAS_INT16X4 Vector) { return vreinterpret_f16_s16(Vector); } +MlasReinterpretInt16AsFloat16(MLAS_INT16X4 Vector) { return vreinterpret_f16_s16(Vector); } MLAS_FORCEINLINE MLAS_INT16X8 -MlasReinterpretAsInt16(MLAS_FLOAT16X8 Vector) { return vreinterpretq_s16_f16(Vector); } +MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X8 Vector) { return vreinterpretq_s16_f16(Vector); } MLAS_FORCEINLINE MLAS_INT16X4 -MlasReinterpretAsInt16(MLAS_FLOAT16X4 Vector) { return vreinterpret_s16_f16(Vector); } +MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X4 Vector) { return vreinterpret_s16_f16(Vector); } MLAS_FORCEINLINE MLAS_FLOAT16X8 @@ -160,98 +160,98 @@ MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V) MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasAdd(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vaddq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasAdd(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vadd_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X8 -MlasAdd(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +MlasAddInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) { return vaddq_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X4 -MlasAdd(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +MlasAddInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) { return vadd_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasSubtract(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasSubtractFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vsubq_f16(Vector1, Vector2); } 
MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasSubtract(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasSubtractFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vsub_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X8 -MlasSubtract(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +MlasSubtractInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) { return vsubq_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X4 -MlasSubtract(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +MlasSubtractInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) { return vsub_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMultiply(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMultiplyFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vmulq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMultiply(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasMultiplyFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmul_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasDivide(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasDivideFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vdivq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasDivide(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasDivideFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vdiv_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMultiplyAdd(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Dest) +MlasMultiplyAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Dest) { return vfmaq_f16(Dest, Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMultiplyAdd(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Dest) +MlasMultiplyAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Dest) { return vfma_f16(Dest, Vector1, Vector2); } @@ -260,14 +260,14 @@ MLAS_FORCEINLINE void MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, 
_mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3) { - MlasMultiplyAdd(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); + MlasMultiplyAddFloat16(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); } MLAS_FORCEINLINE void MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3) { - MlasMultiplyAdd(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); + MlasMultiplyAddFloat16(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); } MLAS_FORCEINLINE @@ -315,56 +315,56 @@ MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMaximum(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMaximumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vmaxq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMaximum(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasMaximumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmax_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X8 -MlasMaximum(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +MlasMaximumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) { return vmaxq_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X4 -MlasMaximum(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +MlasMaximumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) { return vmax_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMinimum(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMinimumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vminq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMinimum(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasMinimumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmin_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X8 -MlasMinimum(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +MlasMinimumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) { return vminq_s16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_INT16X4 
-MlasMinimum(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +MlasMinimumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) { return vmin_s16(Vector1, Vector2); } @@ -373,24 +373,34 @@ MLAS_FORCEINLINE MLAS_FLOAT16X8 MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange) { - Value = MlasMaximum(MlasBroadcastFloat16x8(LowerRange), Value); - Value = MlasMinimum(MlasBroadcastFloat16x8(UpperRange), Value); + Value = MlasMaximumFloat16(MlasBroadcastFloat16x8(LowerRange), Value); + Value = MlasMinimumFloat16(MlasBroadcastFloat16x8(UpperRange), Value); return Value; } template <typename T> MLAS_FORCEINLINE T -MlasClamp(T Value, T LowerRange, T UpperRange) +MlasClampFloat16(T Value, T LowerRange, T UpperRange) { - Value = MlasMaximum(LowerRange, Value); - Value = MlasMinimum(UpperRange, Value); + Value = MlasMaximumFloat16(LowerRange, Value); + Value = MlasMinimumFloat16(UpperRange, Value); + return Value; +} + +template <typename T> +MLAS_FORCEINLINE +T +MlasClampInt16(T Value, T LowerRange, T UpperRange) +{ + Value = MlasMaximumInt16(LowerRange, Value); + Value = MlasMinimumInt16(UpperRange, Value); return Value; } MLAS_FORCEINLINE _mlas_fp16_ -MlasReduceAdd(MLAS_FLOAT16X8 Vector) +MlasReduceAddFloat16(MLAS_FLOAT16X8 Vector) { Vector = vpaddq_f16(Vector, Vector); Vector = vpaddq_f16(Vector, Vector); @@ -400,7 +410,7 @@ MlasReduceAdd(MLAS_FLOAT16X8 Vector) MLAS_FORCEINLINE _mlas_fp16_ -MlasReduceAdd(MLAS_FLOAT16X4 Vector) +MlasReduceAddFloat16(MLAS_FLOAT16X4 Vector) { Vector = vpadd_f16(Vector, Vector); Vector = vpadd_f16(Vector, Vector); @@ -409,7 +419,7 @@ MlasReduceAdd(MLAS_FLOAT16X4 Vector) MLAS_FORCEINLINE _mlas_fp16_ -MlasReduceMaximum(MLAS_FLOAT16X8 Vector) +MlasReduceMaximumFloat16(MLAS_FLOAT16X8 Vector) { Vector = vpmaxq_f16(Vector, Vector); Vector = vpmaxq_f16(Vector, Vector); @@ -419,7 +429,7 @@ MlasReduceMaximum(MLAS_FLOAT16X8 Vector) MLAS_FORCEINLINE _mlas_fp16_ -MlasReduceMaximum(MLAS_FLOAT16X4 Vector) +MlasReduceMaximumFloat16(MLAS_FLOAT16X4 Vector) { 
Vector = vpmax_f16(Vector, Vector); Vector = vpmax_f16(Vector, Vector); @@ -556,7 +566,7 @@ Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FL template MLAS_FORCEINLINE MLAS_INT16X8 -MlasShiftLeft(MLAS_INT16X8 Vector) +MlasShiftLeftInt16(MLAS_INT16X8 Vector) { return vshlq_n_s16(Vector, ShiftCount); } @@ -564,7 +574,7 @@ MlasShiftLeft(MLAS_INT16X8 Vector) template MLAS_FORCEINLINE MLAS_INT16X4 -MlasShiftLeft(MLAS_INT16X4 Vector) +MlasShiftLeftInt16(MLAS_INT16X4 Vector) { return vshl_n_s16(Vector, ShiftCount); } diff --git a/onnxruntime/core/mlas/lib/pooling_fp16.cpp b/onnxruntime/core/mlas/lib/pooling_fp16.cpp index 62670a6986..9765192d25 100644 --- a/onnxruntime/core/mlas/lib/pooling_fp16.cpp +++ b/onnxruntime/core/mlas/lib/pooling_fp16.cpp @@ -84,7 +84,7 @@ MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) { - return MlasMaximum(agg, element); + return MlasMaximumFloat16(agg, element); } template<> @@ -92,7 +92,7 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) { - return MlasMaximum(agg, element); + return MlasMaximumFloat16(agg, element); } template<> @@ -144,28 +144,28 @@ template <> MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) { - return MlasAdd(agg, element); + return MlasAddFloat16(agg, element); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) { - return MlasAdd(agg, element); + return MlasAddFloat16(agg, element); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolSummary16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 context) { - return MlasDivide(agg, context); + return MlasDivideFloat16(agg, context); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolSummary16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X8 context) { - return MlasDivide(agg, MlasToLowHalfFloat16x4(context)); + return MlasDivideFloat16(agg, MlasToLowHalfFloat16x4(context)); } 
diff --git a/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp index 819b7fe941..2678d3d6c6 100644 --- a/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp @@ -124,40 +124,40 @@ template MLAS_FORCEINLINE T Exp_Vector_Fp16(T x) { const auto constants = Get_Exp_Constants(); - auto clamped_x = MlasClamp(x, constants.LowerRange, constants.UpperRange); + auto clamped_x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); // integral - auto biased = MlasMultiplyAdd(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); - auto m = MlasSubtract(biased, constants.RoundingBias); + auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); + auto m = MlasSubtractFloat16(biased, constants.RoundingBias); // residual - auto r = MlasMultiplyAdd(m, constants.Log2High, clamped_x); - r = MlasMultiplyAdd(m, constants.Log2Mid, r); - r = MlasMultiplyAdd(m, constants.Log2Low, r); + auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x); + r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r); + r = MlasMultiplyAddFloat16(m, constants.Log2Low, r); // handle overflow - auto overflow = MlasShiftLeft<10>(MlasReinterpretAsInt16(biased)); + auto overflow = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased)); auto normal = overflow; - auto minimum_exponent = MlasReinterpretAsInt16(constants.MinimumExponent); - auto maximum_exponent = MlasReinterpretAsInt16(constants.MaximumExponent); - normal = MlasClamp(normal, minimum_exponent, maximum_exponent); + auto minimum_exponent = MlasReinterpretFloat16AsInt16(constants.MinimumExponent); + auto maximum_exponent = MlasReinterpretFloat16AsInt16(constants.MaximumExponent); + normal = MlasClampInt16(normal, minimum_exponent, maximum_exponent); - overflow = MlasSubtract(overflow, normal); - overflow = MlasAdd(overflow, maximum_exponent); - normal = 
MlasAdd(normal, maximum_exponent); + overflow = MlasSubtractInt16(overflow, normal); + overflow = MlasAddInt16(overflow, maximum_exponent); + normal = MlasAddInt16(normal, maximum_exponent); // polynomial approximation - auto p = MlasMultiplyAdd(constants.poly_0, r, constants.poly_1); - p = MlasMultiplyAdd(p, r, constants.poly_2); - p = MlasMultiplyAdd(p, r, constants.poly_3); - p = MlasMultiplyAdd(p, r, constants.poly_4); - p = MlasMultiplyAdd(p, r, constants.poly_56); + auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1); + p = MlasMultiplyAddFloat16(p, r, constants.poly_2); + p = MlasMultiplyAddFloat16(p, r, constants.poly_3); + p = MlasMultiplyAddFloat16(p, r, constants.poly_4); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); - auto overflow_f = MlasReinterpretAsFloat16(overflow); - r = MlasMultiply(r, overflow_f); - p = MlasMultiplyAdd(p, r, overflow_f); - p = MlasMultiply(p, MlasReinterpretAsFloat16(normal)); + auto overflow_f = MlasReinterpretInt16AsFloat16(overflow); + r = MlasMultiplyFloat16(r, overflow_f); + p = MlasMultiplyAddFloat16(p, r, overflow_f); + p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal)); return p; } @@ -242,44 +242,44 @@ template MLAS_FORCEINLINE T SumExp_Vector_Fp16(T x, T negative_maximum) { const auto constants = Get_Exp_Constants(); - auto clamped_x = MlasMaximum(MlasAdd(x, negative_maximum), constants.LowerRangeSumExp); + auto clamped_x = MlasMaximumFloat16(MlasAddFloat16(x, negative_maximum), constants.LowerRangeSumExp); // integral - auto biased = MlasMultiplyAdd(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); - auto m = MlasSubtract(biased, constants.RoundingBias); + auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); + auto m = MlasSubtractFloat16(biased, constants.RoundingBias); // residual - auto r = MlasMultiplyAdd(m, constants.Log2High, clamped_x); - r = MlasMultiplyAdd(m, constants.Log2Mid, r); - r = MlasMultiplyAdd(m, 
constants.Log2Low, r); + auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x); + r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r); + r = MlasMultiplyAddFloat16(m, constants.Log2Low, r); // 2^m - auto normal = MlasShiftLeft<10>(MlasReinterpretAsInt16(biased)); - normal = MlasAdd(normal, MlasReinterpretAsInt16(constants.MaximumExponent)); + auto normal = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased)); + normal = MlasAddInt16(normal, MlasReinterpretFloat16AsInt16(constants.MaximumExponent)); // polynomial approximation - auto p = MlasMultiplyAdd(constants.poly_0, r, constants.poly_1); - p = MlasMultiplyAdd(p, r, constants.poly_2); - p = MlasMultiplyAdd(p, r, constants.poly_3); - p = MlasMultiplyAdd(p, r, constants.poly_4); - p = MlasMultiplyAdd(p, r, constants.poly_56); - p = MlasMultiplyAdd(p, r, constants.poly_56); + auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1); + p = MlasMultiplyAddFloat16(p, r, constants.poly_2); + p = MlasMultiplyAddFloat16(p, r, constants.poly_3); + p = MlasMultiplyAddFloat16(p, r, constants.poly_4); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); - p = MlasMultiply(p, MlasReinterpretAsFloat16(normal)); + p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal)); return p; } MLAS_FORCEINLINE float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, float16x8_t v4) { - v0 = MlasAdd(v0, v1); - v2 = MlasAdd(v2, v3); - return MlasAdd(MlasAdd(v0, v2), v4); + v0 = MlasAddFloat16(v0, v1); + v2 = MlasAddFloat16(v2, v3); + return MlasAddFloat16(MlasAddFloat16(v0, v2), v4); } MLAS_FORCEINLINE float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2) { - return MlasAdd(MlasAdd(v0, v1), v2); + return MlasAddFloat16(MlasAddFloat16(v0, v1), v2); } MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum) { @@ -338,7 +338,7 @@ MLAS_FP16 
SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N if (N & 8) { auto v0 = MlasLoadFloat16x8(input); auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); - accumulator8 = MlasAdd(r0, accumulator8); + accumulator8 = MlasAddFloat16(r0, accumulator8); if (store_output) { MlasStoreFloat16x8(output, r0); @@ -352,7 +352,7 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N if (N & 4) { auto v0 = MlasLoadFloat16x4(input); auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); - accumulator4 = MlasAdd(r0, accumulator4); + accumulator4 = MlasAddFloat16(r0, accumulator4); if (store_output) { MlasStoreFloat16x4(output, r0); @@ -371,8 +371,8 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N MlasStorePartialFloat16x4(output, r0, 3); } - r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 3)); - accumulator4 = MlasAdd(r0, accumulator4); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 3)); + accumulator4 = MlasAddFloat16(r0, accumulator4); } else if (N == 2) { auto v0 = MlasLoadPartialFloat16x4(input, 2); auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); @@ -381,9 +381,9 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N MlasStorePartialFloat16x4(output, r0, 2); } - r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 3)); - r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 2)); - accumulator4 = MlasAdd(r0, accumulator4); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 3)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 2)); + accumulator4 = MlasAddFloat16(r0, accumulator4); } else if (N == 1) { auto v0 = MlasLoadPartialFloat16x4(input, 1); auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); @@ -392,15 
+392,15 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N MlasStorePartialFloat16x4(output, r0, 1); } - r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 3)); - r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 2)); - r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 1)); - accumulator4 = MlasAdd(r0, accumulator4); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 3)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 2)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 1)); + accumulator4 = MlasAddFloat16(r0, accumulator4); } - auto t = MlasAdd(vget_low_f16(accumulator8), vget_high_f16(accumulator8)); - t = MlasAdd(t, accumulator4); - _mlas_fp16_ result = MlasReduceAdd(t); + auto t = MlasAddFloat16(vget_low_f16(accumulator8), vget_high_f16(accumulator8)); + t = MlasAddFloat16(t, accumulator4); + _mlas_fp16_ result = MlasReduceAddFloat16(t); return MLAS_FP16::FromBits(result); } @@ -478,20 +478,20 @@ template MLAS_FORCEINLINE T Tanh_Vector_Fp16(T x) { const auto constants = Get_Tanh_Constants(); - x = MlasClamp(x, constants.LowerRange, constants.UpperRange); + x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); - T x_2 = MlasMultiply(x, x); + T x_2 = MlasMultiplyFloat16(x, x); - T p = MlasMultiplyAdd(constants.alpha_7, x_2, constants.alpha_5); - p = MlasMultiplyAdd(p, x_2, constants.alpha_3); - p = MlasMultiplyAdd(p, x_2, constants.alpha_1); - p = MlasMultiply(p, x); + T p = MlasMultiplyAddFloat16(constants.alpha_7, x_2, constants.alpha_5); + p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_3); + p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_1); + p = MlasMultiplyFloat16(p, x); - T q = MlasMultiplyAdd(constants.beta_6, x_2, constants.beta_4); - q = 
MlasMultiplyAdd(q, x_2, constants.beta_2); - q = MlasMultiplyAdd(q, x_2, constants.beta_0); + T q = MlasMultiplyAddFloat16(constants.beta_6, x_2, constants.beta_4); + q = MlasMultiplyAddFloat16(q, x_2, constants.beta_2); + q = MlasMultiplyAddFloat16(q, x_2, constants.beta_0); - return MlasDivide(p, q); + return MlasDivideFloat16(p, q); } void Tanh_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { @@ -576,8 +576,8 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co auto softcap4 = MlasBroadcastFloat16x4(Softcap.val); auto one8 = MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00); auto one4 = MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00); - auto softcap_reciprocal8 = MlasDivide(one8, softcap8); - auto softcap_reciprocal4 = MlasDivide(one4, softcap4); + auto softcap_reciprocal8 = MlasDivideFloat16(one8, softcap8); + auto softcap_reciprocal4 = MlasDivideFloat16(one4, softcap4); while (N >= 32) { auto v0 = MlasLoadFloat16x8(input); @@ -585,20 +585,20 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co auto v2 = MlasLoadFloat16x8(input + 16); auto v3 = MlasLoadFloat16x8(input + 24); - v0 = MlasMultiply(v0, softcap_reciprocal8); - v1 = MlasMultiply(v1, softcap_reciprocal8); - v2 = MlasMultiply(v2, softcap_reciprocal8); - v3 = MlasMultiply(v3, softcap_reciprocal8); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8); + v2 = MlasMultiplyFloat16(v2, softcap_reciprocal8); + v3 = MlasMultiplyFloat16(v3, softcap_reciprocal8); v0 = Tanh_Vector_Fp16(v0); v1 = Tanh_Vector_Fp16(v1); v2 = Tanh_Vector_Fp16(v2); v3 = Tanh_Vector_Fp16(v3); - v0 = MlasMultiply(v0, softcap8); - v1 = MlasMultiply(v1, softcap8); - v2 = MlasMultiply(v2, softcap8); - v3 = MlasMultiply(v3, softcap8); + v0 = MlasMultiplyFloat16(v0, softcap8); + v1 = MlasMultiplyFloat16(v1, softcap8); + v2 = MlasMultiplyFloat16(v2, softcap8); + v3 = MlasMultiplyFloat16(v3, softcap8); 
MlasStoreFloat16x8(output, v0); MlasStoreFloat16x8(output + 8, v1); @@ -614,14 +614,14 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co auto v0 = MlasLoadFloat16x8(input); auto v1 = MlasLoadFloat16x8(input + 8); - v0 = MlasMultiply(v0, softcap_reciprocal8); - v1 = MlasMultiply(v1, softcap_reciprocal8); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8); v0 = Tanh_Vector_Fp16(v0); v1 = Tanh_Vector_Fp16(v1); - v0 = MlasMultiply(v0, softcap8); - v1 = MlasMultiply(v1, softcap8); + v0 = MlasMultiplyFloat16(v0, softcap8); + v1 = MlasMultiplyFloat16(v1, softcap8); MlasStoreFloat16x8(output, v0); MlasStoreFloat16x8(output + 8, v1); @@ -633,9 +633,9 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co if (N & 8) { auto v0 = MlasLoadFloat16x8(input); - v0 = MlasMultiply(v0, softcap_reciprocal8); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); v0 = Tanh_Vector_Fp16(v0); - v0 = MlasMultiply(v0, softcap8); + v0 = MlasMultiplyFloat16(v0, softcap8); MlasStoreFloat16x8(output, v0); input += 8; @@ -645,9 +645,9 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co if (N & 4) { auto v0 = MlasLoadFloat16x4(input); - v0 = MlasMultiply(v0, softcap_reciprocal4); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); v0 = Tanh_Vector_Fp16(v0); - v0 = MlasMultiply(v0, softcap4); + v0 = MlasMultiplyFloat16(v0, softcap4); MlasStoreFloat16x4(output, v0); input += 4; @@ -657,21 +657,21 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co if (N == 3) { auto v0 = MlasLoadPartialFloat16x4(input, 3); - v0 = MlasMultiply(v0, softcap_reciprocal4); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); v0 = Tanh_Vector_Fp16(v0); - v0 = MlasMultiply(v0, softcap4); + v0 = MlasMultiplyFloat16(v0, softcap4); MlasStorePartialFloat16x4(output, v0, 3); } else if (N == 2) { auto v0 = MlasLoadPartialFloat16x4(input, 2); - v0 = 
MlasMultiply(v0, softcap_reciprocal4); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); v0 = Tanh_Vector_Fp16(v0); - v0 = MlasMultiply(v0, softcap4); + v0 = MlasMultiplyFloat16(v0, softcap4); MlasStorePartialFloat16x4(output, v0, 2); } else if (N == 1) { auto v0 = MlasLoadPartialFloat16x4(input, 1); - v0 = MlasMultiply(v0, softcap_reciprocal4); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); v0 = Tanh_Vector_Fp16(v0); - v0 = MlasMultiply(v0, softcap4); + v0 = MlasMultiplyFloat16(v0, softcap4); MlasStorePartialFloat16x4(output, v0, 1); } } @@ -687,10 +687,10 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { auto v2 = MlasLoadFloat16x8(input + 16); auto v3 = MlasLoadFloat16x8(input + 24); - v0 = MlasMaximum(v0, v1); - v2 = MlasMaximum(v2, v3); - v0 = MlasMaximum(v0, v2); - max8 = MlasMaximum(max8, v0); + v0 = MlasMaximumFloat16(v0, v1); + v2 = MlasMaximumFloat16(v2, v3); + v0 = MlasMaximumFloat16(v0, v2); + max8 = MlasMaximumFloat16(max8, v0); input += 32; N -= 32; @@ -700,8 +700,8 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { auto v0 = MlasLoadFloat16x8(input); auto v1 = MlasLoadFloat16x8(input + 8); - v0 = MlasMaximum(v0, v1); - max8 = MlasMaximum(max8, v0); + v0 = MlasMaximumFloat16(v0, v1); + max8 = MlasMaximumFloat16(max8, v0); input += 16; N -= 16; @@ -709,7 +709,7 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { if (N & 8) { auto v0 = MlasLoadFloat16x8(input); - max8 = MlasMaximum(max8, v0); + max8 = MlasMaximumFloat16(max8, v0); input += 8; N -= 8; @@ -717,7 +717,7 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { if (N & 4) { auto v0 = MlasLoadFloat16x4(input); - max4 = MlasMaximum(max4, v0); + max4 = MlasMaximumFloat16(max4, v0); input += 4; N -= 4; @@ -725,24 +725,24 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { if (N == 3) { auto v0 = MlasLoadPartialFloat16x4(input, 3); - v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, 
MlasReinterpretAsInt16(v0), 3)); - max4 = MlasMaximum(max4, v0); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + max4 = MlasMaximumFloat16(max4, v0); } else if (N == 2) { auto v0 = MlasLoadPartialFloat16x4(input, 2); - v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 3)); - v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 2)); - max4 = MlasMaximum(max4, v0); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2)); + max4 = MlasMaximumFloat16(max4, v0); } else if (N == 1) { auto v0 = MlasLoadPartialFloat16x4(input, 1); - v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 3)); - v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 2)); - v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 1)); - max4 = MlasMaximum(max4, v0); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 1)); + max4 = MlasMaximumFloat16(max4, v0); } - auto t = MlasMaximum(vget_low_f16(max8), vget_high_f16(max8)); - t = MlasMaximum(t, max4); - _mlas_fp16_ result = MlasReduceMaximum(t); + auto t = MlasMaximumFloat16(vget_low_f16(max8), vget_high_f16(max8)); + t = MlasMaximumFloat16(t, max4); + _mlas_fp16_ result = MlasReduceMaximumFloat16(t); return MLAS_FP16::FromBits(result); } @@ -752,8 +752,8 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co auto* output = 
reinterpret_cast<_mlas_fp16_*>(Output); auto sum8 = MlasBroadcastFloat16x8(Sum.val); auto sum4 = MlasBroadcastFloat16x4(Sum.val); - auto scale8 = MlasDivide(MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00), sum8); - auto scale4 = MlasDivide(MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00), sum4); + auto scale8 = MlasDivideFloat16(MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00), sum8); + auto scale4 = MlasDivideFloat16(MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00), sum4); while (N >= 32) { auto v0 = MlasLoadFloat16x8(input); @@ -761,10 +761,10 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co auto v2 = MlasLoadFloat16x8(input + 16); auto v3 = MlasLoadFloat16x8(input + 24); - v0 = MlasMultiply(v0, scale8); - v1 = MlasMultiply(v1, scale8); - v2 = MlasMultiply(v2, scale8); - v3 = MlasMultiply(v3, scale8); + v0 = MlasMultiplyFloat16(v0, scale8); + v1 = MlasMultiplyFloat16(v1, scale8); + v2 = MlasMultiplyFloat16(v2, scale8); + v3 = MlasMultiplyFloat16(v3, scale8); MlasStoreFloat16x8(output, v0); MlasStoreFloat16x8(output + 8, v1); @@ -780,8 +780,8 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co auto v0 = MlasLoadFloat16x8(input); auto v1 = MlasLoadFloat16x8(input + 8); - v0 = MlasMultiply(v0, scale8); - v1 = MlasMultiply(v1, scale8); + v0 = MlasMultiplyFloat16(v0, scale8); + v1 = MlasMultiplyFloat16(v1, scale8); MlasStoreFloat16x8(output, v0); MlasStoreFloat16x8(output + 8, v1); @@ -793,7 +793,7 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co if (N & 8) { auto v0 = MlasLoadFloat16x8(input); - v0 = MlasMultiply(v0, scale8); + v0 = MlasMultiplyFloat16(v0, scale8); MlasStoreFloat16x8(output, v0); input += 8; @@ -803,7 +803,7 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co if (N & 4) { auto v0 = MlasLoadFloat16x4(input); - v0 = MlasMultiply(v0, scale4); + v0 = MlasMultiplyFloat16(v0, scale4); MlasStoreFloat16x4(output, v0); input += 4; @@ -813,15 
+813,15 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co if (N == 3) { auto v0 = MlasLoadPartialFloat16x4(input, 3); - v0 = MlasMultiply(v0, scale4); + v0 = MlasMultiplyFloat16(v0, scale4); MlasStorePartialFloat16x4(output, v0, 3); } else if (N == 2) { auto v0 = MlasLoadPartialFloat16x4(input, 2); - v0 = MlasMultiply(v0, scale4); + v0 = MlasMultiplyFloat16(v0, scale4); MlasStorePartialFloat16x4(output, v0, 2); } else if (N == 1) { auto v0 = MlasLoadPartialFloat16x4(input, 1); - v0 = MlasMultiply(v0, scale4); + v0 = MlasMultiplyFloat16(v0, scale4); MlasStorePartialFloat16x4(output, v0, 1); } } @@ -840,15 +840,15 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, auto v2 = MlasLoadFloat16x8(input + 16); auto v3 = MlasLoadFloat16x8(input + 24); - v0 = MlasAdd(v0, negative_maximum8); - v1 = MlasAdd(v1, negative_maximum8); - v2 = MlasAdd(v2, negative_maximum8); - v3 = MlasAdd(v3, negative_maximum8); + v0 = MlasAddFloat16(v0, negative_maximum8); + v1 = MlasAddFloat16(v1, negative_maximum8); + v2 = MlasAddFloat16(v2, negative_maximum8); + v3 = MlasAddFloat16(v3, negative_maximum8); - v0 = MlasSubtract(v0, log_sum8); - v1 = MlasSubtract(v1, log_sum8); - v2 = MlasSubtract(v2, log_sum8); - v3 = MlasSubtract(v3, log_sum8); + v0 = MlasSubtractFloat16(v0, log_sum8); + v1 = MlasSubtractFloat16(v1, log_sum8); + v2 = MlasSubtractFloat16(v2, log_sum8); + v3 = MlasSubtractFloat16(v3, log_sum8); MlasStoreFloat16x8(output, v0); MlasStoreFloat16x8(output + 8, v1); @@ -864,11 +864,11 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, auto v0 = MlasLoadFloat16x8(input); auto v1 = MlasLoadFloat16x8(input + 8); - v0 = MlasAdd(v0, negative_maximum8); - v1 = MlasAdd(v1, negative_maximum8); + v0 = MlasAddFloat16(v0, negative_maximum8); + v1 = MlasAddFloat16(v1, negative_maximum8); - v0 = MlasSubtract(v0, log_sum8); - v1 = MlasSubtract(v1, log_sum8); + v0 = MlasSubtractFloat16(v0, log_sum8); + 
v1 = MlasSubtractFloat16(v1, log_sum8); MlasStoreFloat16x8(output, v0); MlasStoreFloat16x8(output + 8, v1); @@ -880,8 +880,8 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, if (N & 8) { auto v0 = MlasLoadFloat16x8(input); - v0 = MlasAdd(v0, negative_maximum8); - v0 = MlasSubtract(v0, log_sum8); + v0 = MlasAddFloat16(v0, negative_maximum8); + v0 = MlasSubtractFloat16(v0, log_sum8); MlasStoreFloat16x8(output, v0); input += 8; @@ -891,8 +891,8 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, if (N & 4) { auto v0 = MlasLoadFloat16x4(input); - v0 = MlasAdd(v0, negative_maximum4); - v0 = MlasSubtract(v0, log_sum4); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); MlasStoreFloat16x4(output, v0); input += 4; @@ -902,18 +902,18 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, if (N == 3) { auto v0 = MlasLoadPartialFloat16x4(input, 3); - v0 = MlasAdd(v0, negative_maximum4); - v0 = MlasSubtract(v0, log_sum4); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); MlasStorePartialFloat16x4(output, v0, 3); } else if (N == 2) { auto v0 = MlasLoadPartialFloat16x4(input, 2); - v0 = MlasAdd(v0, negative_maximum4); - v0 = MlasSubtract(v0, log_sum4); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); MlasStorePartialFloat16x4(output, v0, 2); } else if (N == 1) { auto v0 = MlasLoadPartialFloat16x4(input, 1); - v0 = MlasAdd(v0, negative_maximum4); - v0 = MlasSubtract(v0, log_sum4); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); MlasStorePartialFloat16x4(output, v0, 1); } } diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 763b292439..041b6c61cd 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -47,7 
+47,7 @@ class MlasComputeExpTest : public MlasTestBase { MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); - std::default_random_engine generator(N); + std::default_random_engine generator(static_cast<unsigned>(N)); std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue); for (size_t n = 0; n < N; n++) { @@ -73,13 +73,13 @@ class MlasComputeExpTest : public MlasTestBase { MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); - std::default_random_engine generator(N); + std::default_random_engine generator(static_cast<unsigned>(N)); std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue); - float max_val = std::numeric_limits<float>::min(); + float max_val = std::numeric_limits<float>::lowest(); for (size_t n = 0; n < N; n++) { Input[n] = MLAS_FP16(distribution(generator)); - max_val = std::max(max_val, Input[n].ToFloat()); + max_val = std::fmax(max_val, Input[n].ToFloat()); } const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; @@ -177,7 +177,7 @@ class MlasSoftmaxTest : public MlasTestBase { for (size_t nd = 0; nd < N; nd++) { Input[nd] = MLAS_FP16(distribution(generator)); - ref = std::max(ref, Input[nd].ToFloat()); + ref = std::fmax(ref, Input[nd].ToFloat()); } const auto* dispatch = GetMlasPlatform().SoftmaxDispatch;