fix windows build

fajin-corp 2025-02-07 22:52:42 +00:00
parent 56b85ed252
commit 94e80536ce
6 changed files with 225 additions and 215 deletions
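Why renaming the overloads fixes the Windows build: GCC and Clang expose each NEON vector type (float16x8_t, int16x8_t, ...) as a distinct builtin type, so overload sets like MlasAdd(MLAS_FLOAT16X8, MLAS_FLOAT16X8) and MlasAdd(MLAS_INT16X8, MLAS_INT16X8) coexist on Linux ARM64. MSVC's ARM64 <arm_neon.h>, however, is believed to define the NEON vector types as typedefs of shared structs (__n128 for 128-bit, __n64 for 64-bit), so overloads that differ only in NEON element type collapse into duplicate definitions. Moving the element type into the function name (MlasAdd -> MlasAddFloat16 / MlasAddInt16, and likewise for Subtract, Multiply, Divide, Maximum, Minimum, Clamp, Reinterpret, ReduceAdd, ReduceMaximum, ShiftLeft) takes overload resolution out of the picture. A minimal sketch of the failure mode, as a hypothetical standalone file (not part of this commit):

#include <arm_neon.h>

// On GCC/Clang these are two distinct overloads. Under MSVC targeting ARM64,
// float16x8_t and int16x8_t are (assumed here) both typedefs of __n128, so
// the second definition repeats the first one's signature and is rejected.
float16x8_t MlasAdd(float16x8_t Vector1, float16x8_t Vector2) { return vaddq_f16(Vector1, Vector2); }
int16x8_t MlasAdd(int16x8_t Vector1, int16x8_t Vector2) { return vaddq_s16(Vector1, Vector2); }

// The pattern applied throughout this commit: encode the element type in the
// name instead of the parameter list, so no two helpers share a signature.
float16x8_t MlasAddFloat16(float16x8_t Vector1, float16x8_t Vector2) { return vaddq_f16(Vector1, Vector2); }
int16x8_t MlasAddInt16(int16x8_t Vector1, int16x8_t Vector2) { return vaddq_s16(Vector1, Vector2); }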

View file

@@ -51,12 +51,12 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasReluActivation>
     MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
     {
-        return MlasMaximum(ZeroVec, Value);
+        return MlasMaximumFloat16(ZeroVec, Value);
     }
 
     MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
     {
-        return MlasMaximum(MlasToLowHalfFloat16x4(ZeroVec), Value);
+        return MlasMaximumFloat16(MlasToLowHalfFloat16x4(ZeroVec), Value);
     }
 };
@@ -75,7 +75,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasLeakyReluActivation>
     MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
     {
-        MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiply(Value, AlphaBroadcast);
+        MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16(Value, AlphaBroadcast);
         return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec),
                                           ValueTimesAlpha, Value);
     }
@@ -83,7 +83,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasLeakyReluActivation>
     MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
     {
         MLAS_FLOAT16X4 ValueTimesAlpha =
-            MlasMultiply(Value, MlasToLowHalfFloat16x4(AlphaBroadcast));
+            MlasMultiplyFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast));
         return MlasBitwiseSelectFloat16x4(
             MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha,
             Value);
@@ -539,16 +539,16 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasClipActivation> {
     MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
     {
-        Value = MlasMaximum(MinimumBroadcast, Value);
-        Value = MlasMinimum(MaximumBroadcast, Value);
+        Value = MlasMaximumFloat16(MinimumBroadcast, Value);
+        Value = MlasMinimumFloat16(MaximumBroadcast, Value);
         return Value;
     }
 
     MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
     {
-        Value = MlasMaximum(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
-        Value = MlasMinimum(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
+        Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
+        Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
         return Value;
     }
 };
@@ -573,19 +573,19 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasHardSigmoidActivation>
     MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
     {
-        Value = MlasMultiplyAdd(Value, AlphaBroadcast, BetaBroadcast);
-        Value = MlasMinimum(MaximumBroadcast, Value);
-        Value = MlasMaximum(MinimumBroadcast, Value);
+        Value = MlasMultiplyAddFloat16(Value, AlphaBroadcast, BetaBroadcast);
+        Value = MlasMinimumFloat16(MaximumBroadcast, Value);
+        Value = MlasMaximumFloat16(MinimumBroadcast, Value);
         return Value;
     }
 
     MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
     {
-        Value = MlasMultiplyAdd(Value, MlasToLowHalfFloat16x4(AlphaBroadcast),
+        Value = MlasMultiplyAddFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast),
                                 MlasToLowHalfFloat16x4(BetaBroadcast));
-        Value = MlasMinimum(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
-        Value = MlasMaximum(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
+        Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
+        Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
         return Value;
     }
@@ -692,7 +692,7 @@ MlasActivationKernel(
         MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc);
         MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
         addsrc += 8;
-        Vector = MlasAdd(Vector, AVec);
+        Vector = MlasAddFloat16(Vector, AVec);
         Vector = ActivationFunction.Activate(Vector);
         MlasStoreFloat16x8(buffer, Vector);
         buffer += 8;
@@ -703,7 +703,7 @@ MlasActivationKernel(
         MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc);
         MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
         addsrc += 4;
-        Vector = MlasAdd(Vector, AVec);
+        Vector = MlasAddFloat16(Vector, AVec);
         Vector = ActivationFunction.Activate(Vector);
         MlasStoreFloat16x4(buffer, Vector);
         buffer += 4;
@@ -715,7 +715,7 @@ MlasActivationKernel(
         MLAS_FLOAT16X4 buf;
         std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_));
         std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
-        buf = MlasAdd(buf, addbuf);
+        buf = MlasAddFloat16(buf, addbuf);
         buf = ActivationFunction.Activate(buf);
         MlasStorePartialFloat16x4(buffer, buf, n);
     }

View file

@@ -43,7 +43,7 @@ MlasConvDepthwiseKernel(
             MLAS_FLOAT16X8 InputVector = MlasLoadFloat16x8(&Input[k][ChannelOffset]);
             MLAS_FLOAT16X8 FilterVector = MlasLoadFloat16x8(&Filter[ChannelKernelOffset]);
-            Accumulator = MlasMultiplyAdd(InputVector, FilterVector, Accumulator);
+            Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator);
             ChannelKernelOffset += Channels;
         }
         MlasStoreFloat16x8(Output, Accumulator);
@@ -61,7 +61,7 @@ MlasConvDepthwiseKernel(
             MLAS_FLOAT16X4 InputVector = MlasLoadFloat16x4(&Input[k][ChannelOffset]);
             MLAS_FLOAT16X4 FilterVector = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]);
-            Accumulator = MlasMultiplyAdd(InputVector, FilterVector, Accumulator);
+            Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator);
             ChannelKernelOffset += Channels;
         }
         MlasStoreFloat16x4(Output, Accumulator);
@@ -80,7 +80,7 @@ MlasConvDepthwiseKernel(
             MLAS_FLOAT16X4 InputValue = MlasLoadFloat16x4(&Input[k][ChannelOffset]);
             MLAS_FLOAT16X4 FilterValue = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]);
-            Accumulator = MlasMultiplyAdd(InputValue, FilterValue, Accumulator);
+            Accumulator = MlasMultiplyAddFloat16(InputValue, FilterValue, Accumulator);
             ChannelKernelOffset += Channels;
         }
         MlasStorePartialFloat16x4(Output, Accumulator, c);

View file

@@ -32,23 +32,23 @@ typedef int16x4_t MLAS_INT16X4;
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasReinterpretAsFloat16(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); }
+MlasReinterpretInt32AsFloat16(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasReinterpretAsFloat16(MLAS_INT16X8 Vector) { return vreinterpretq_f16_s16(Vector); }
+MlasReinterpretInt16AsFloat16(MLAS_INT16X8 Vector) { return vreinterpretq_f16_s16(Vector); }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasReinterpretAsFloat16(MLAS_INT16X4 Vector) { return vreinterpret_f16_s16(Vector); }
+MlasReinterpretInt16AsFloat16(MLAS_INT16X4 Vector) { return vreinterpret_f16_s16(Vector); }
 
 MLAS_FORCEINLINE
 MLAS_INT16X8
-MlasReinterpretAsInt16(MLAS_FLOAT16X8 Vector) { return vreinterpretq_s16_f16(Vector); }
+MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X8 Vector) { return vreinterpretq_s16_f16(Vector); }
 
 MLAS_FORCEINLINE
 MLAS_INT16X4
-MlasReinterpretAsInt16(MLAS_FLOAT16X4 Vector) { return vreinterpret_s16_f16(Vector); }
+MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X4 Vector) { return vreinterpret_s16_f16(Vector); }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
@@ -160,98 +160,98 @@ MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V)
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasAdd(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vaddq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasAdd(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vadd_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X8
-MlasAdd(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
+MlasAddInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
 {
     return vaddq_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X4
-MlasAdd(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
+MlasAddInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
 {
     return vadd_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasSubtract(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasSubtractFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vsubq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasSubtract(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasSubtractFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vsub_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X8
-MlasSubtract(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
+MlasSubtractInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
 {
     return vsubq_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X4
-MlasSubtract(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
+MlasSubtractInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
 {
     return vsub_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasMultiply(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasMultiplyFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vmulq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasMultiply(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasMultiplyFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vmul_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasDivide(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasDivideFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vdivq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasDivide(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasDivideFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vdiv_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasMultiplyAdd(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Dest)
+MlasMultiplyAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Dest)
 {
     return vfmaq_f16(Dest, Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasMultiplyAdd(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Dest)
+MlasMultiplyAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Dest)
 {
     return vfma_f16(Dest, Vector1, Vector2);
 }
@@ -260,14 +260,14 @@ MLAS_FORCEINLINE
 void
 MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, _mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3)
 {
-    MlasMultiplyAdd(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3);
+    MlasMultiplyAddFloat16(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3);
 }
 
 MLAS_FORCEINLINE
 void
 MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3)
 {
-    MlasMultiplyAdd(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3));
+    MlasMultiplyAddFloat16(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3));
 }
 
 MLAS_FORCEINLINE
@@ -315,56 +315,56 @@ MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasMaximum(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasMaximumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vmaxq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasMaximum(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasMaximumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vmax_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X8
-MlasMaximum(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
+MlasMaximumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
 {
     return vmaxq_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X4
-MlasMaximum(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
+MlasMaximumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
 {
     return vmax_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasMinimum(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasMinimumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vminq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasMinimum(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasMinimumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vmin_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X8
-MlasMinimum(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
+MlasMinimumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
 {
     return vminq_s16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_INT16X4
-MlasMinimum(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
+MlasMinimumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
 {
     return vmin_s16(Vector1, Vector2);
 }
@@ -373,24 +373,34 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange)
 {
-    Value = MlasMaximum(MlasBroadcastFloat16x8(LowerRange), Value);
-    Value = MlasMinimum(MlasBroadcastFloat16x8(UpperRange), Value);
+    Value = MlasMaximumFloat16(MlasBroadcastFloat16x8(LowerRange), Value);
+    Value = MlasMinimumFloat16(MlasBroadcastFloat16x8(UpperRange), Value);
     return Value;
 }
 
 template <typename T>
 MLAS_FORCEINLINE
 T
-MlasClamp(T Value, T LowerRange, T UpperRange)
+MlasClampFloat16(T Value, T LowerRange, T UpperRange)
 {
-    Value = MlasMaximum(LowerRange, Value);
-    Value = MlasMinimum(UpperRange, Value);
+    Value = MlasMaximumFloat16(LowerRange, Value);
+    Value = MlasMinimumFloat16(UpperRange, Value);
     return Value;
 }
 
+template <typename T>
+MLAS_FORCEINLINE
+T
+MlasClampInt16(T Value, T LowerRange, T UpperRange)
+{
+    Value = MlasMaximumInt16(LowerRange, Value);
+    Value = MlasMinimumInt16(UpperRange, Value);
+    return Value;
+}
 MLAS_FORCEINLINE
 _mlas_fp16_
-MlasReduceAdd(MLAS_FLOAT16X8 Vector)
+MlasReduceAddFloat16(MLAS_FLOAT16X8 Vector)
 {
     Vector = vpaddq_f16(Vector, Vector);
     Vector = vpaddq_f16(Vector, Vector);
@@ -400,7 +410,7 @@ MlasReduceAdd(MLAS_FLOAT16X8 Vector)
 MLAS_FORCEINLINE
 _mlas_fp16_
-MlasReduceAdd(MLAS_FLOAT16X4 Vector)
+MlasReduceAddFloat16(MLAS_FLOAT16X4 Vector)
 {
     Vector = vpadd_f16(Vector, Vector);
     Vector = vpadd_f16(Vector, Vector);
@@ -409,7 +419,7 @@ MlasReduceAdd(MLAS_FLOAT16X4 Vector)
 MLAS_FORCEINLINE
 _mlas_fp16_
-MlasReduceMaximum(MLAS_FLOAT16X8 Vector)
+MlasReduceMaximumFloat16(MLAS_FLOAT16X8 Vector)
 {
     Vector = vpmaxq_f16(Vector, Vector);
     Vector = vpmaxq_f16(Vector, Vector);
@@ -419,7 +429,7 @@ MlasReduceMaximum(MLAS_FLOAT16X8 Vector)
 MLAS_FORCEINLINE
 _mlas_fp16_
-MlasReduceMaximum(MLAS_FLOAT16X4 Vector)
+MlasReduceMaximumFloat16(MLAS_FLOAT16X4 Vector)
 {
     Vector = vpmax_f16(Vector, Vector);
     Vector = vpmax_f16(Vector, Vector);
@@ -556,7 +566,7 @@ Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FL
 template<unsigned ShiftCount>
 MLAS_FORCEINLINE
 MLAS_INT16X8
-MlasShiftLeft(MLAS_INT16X8 Vector)
+MlasShiftLeftInt16(MLAS_INT16X8 Vector)
 {
     return vshlq_n_s16(Vector, ShiftCount);
 }
@@ -564,7 +574,7 @@ MlasShiftLeft(MLAS_INT16X8 Vector)
 template<unsigned ShiftCount>
 MLAS_FORCEINLINE
 MLAS_INT16X4
-MlasShiftLeft(MLAS_INT16X4 Vector)
+MlasShiftLeftInt16(MLAS_INT16X4 Vector)
 {
     return vshl_n_s16(Vector, ShiftCount);
 }

View file

@@ -84,7 +84,7 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 PoolAggregate16x8<MaxPoolAggregation>(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element)
 {
-    return MlasMaximum(agg, element);
+    return MlasMaximumFloat16(agg, element);
 }
 
 template<>
@@ -92,7 +92,7 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X4
 PoolAggregate16x4<MaxPoolAggregation>(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element)
 {
-    return MlasMaximum(agg, element);
+    return MlasMaximumFloat16(agg, element);
 }
 
 template<>
@@ -144,28 +144,28 @@ template <>
 MLAS_FORCEINLINE MLAS_FLOAT16X8
 PoolAggregate16x8<AveragePoolAggregation>(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element)
 {
-    return MlasAdd(agg, element);
+    return MlasAddFloat16(agg, element);
 }
 
 template <>
 MLAS_FORCEINLINE MLAS_FLOAT16X4
 PoolAggregate16x4<AveragePoolAggregation>(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element)
 {
-    return MlasAdd(agg, element);
+    return MlasAddFloat16(agg, element);
 }
 
 template <>
 MLAS_FORCEINLINE MLAS_FLOAT16X8
 PoolSummary16x8<AveragePoolAggregation>(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 context)
 {
-    return MlasDivide(agg, context);
+    return MlasDivideFloat16(agg, context);
 }
 
 template <>
 MLAS_FORCEINLINE MLAS_FLOAT16X4
 PoolSummary16x4<AveragePoolAggregation>(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X8 context)
 {
-    return MlasDivide(agg, MlasToLowHalfFloat16x4(context));
+    return MlasDivideFloat16(agg, MlasToLowHalfFloat16x4(context));
 }

View file

@@ -124,40 +124,40 @@ template<typename T>
 MLAS_FORCEINLINE
 T Exp_Vector_Fp16(T x) {
     const auto constants = Get_Exp_Constants<T>();
-    auto clamped_x = MlasClamp(x, constants.LowerRange, constants.UpperRange);
+    auto clamped_x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange);
 
     // integral
-    auto biased = MlasMultiplyAdd(clamped_x, constants.Log2Reciprocal, constants.RoundingBias);
-    auto m = MlasSubtract(biased, constants.RoundingBias);
+    auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias);
+    auto m = MlasSubtractFloat16(biased, constants.RoundingBias);
 
     // residual
-    auto r = MlasMultiplyAdd(m, constants.Log2High, clamped_x);
-    r = MlasMultiplyAdd(m, constants.Log2Mid, r);
-    r = MlasMultiplyAdd(m, constants.Log2Low, r);
+    auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x);
+    r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r);
+    r = MlasMultiplyAddFloat16(m, constants.Log2Low, r);
 
     // handle overflow
-    auto overflow = MlasShiftLeft<10>(MlasReinterpretAsInt16(biased));
+    auto overflow = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased));
     auto normal = overflow;
-    auto minimum_exponent = MlasReinterpretAsInt16(constants.MinimumExponent);
-    auto maximum_exponent = MlasReinterpretAsInt16(constants.MaximumExponent);
-    normal = MlasClamp(normal, minimum_exponent, maximum_exponent);
+    auto minimum_exponent = MlasReinterpretFloat16AsInt16(constants.MinimumExponent);
+    auto maximum_exponent = MlasReinterpretFloat16AsInt16(constants.MaximumExponent);
+    normal = MlasClampInt16(normal, minimum_exponent, maximum_exponent);
 
-    overflow = MlasSubtract(overflow, normal);
-    overflow = MlasAdd(overflow, maximum_exponent);
-    normal = MlasAdd(normal, maximum_exponent);
+    overflow = MlasSubtractInt16(overflow, normal);
+    overflow = MlasAddInt16(overflow, maximum_exponent);
+    normal = MlasAddInt16(normal, maximum_exponent);
 
     // polynomial approximation
-    auto p = MlasMultiplyAdd(constants.poly_0, r, constants.poly_1);
-    p = MlasMultiplyAdd(p, r, constants.poly_2);
-    p = MlasMultiplyAdd(p, r, constants.poly_3);
-    p = MlasMultiplyAdd(p, r, constants.poly_4);
-    p = MlasMultiplyAdd(p, r, constants.poly_56);
+    auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_2);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_3);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_4);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_56);
 
-    auto overflow_f = MlasReinterpretAsFloat16(overflow);
-    r = MlasMultiply(r, overflow_f);
-    p = MlasMultiplyAdd(p, r, overflow_f);
-    p = MlasMultiply(p, MlasReinterpretAsFloat16(normal));
+    auto overflow_f = MlasReinterpretInt16AsFloat16(overflow);
+    r = MlasMultiplyFloat16(r, overflow_f);
+    p = MlasMultiplyAddFloat16(p, r, overflow_f);
+    p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal));
 
     return p;
 }
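An aside on the algorithm above, since the renames touch its bit tricks: Exp_Vector_Fp16 computes e^x by splitting x = m*ln2 + r, building 2^m with MlasShiftLeftInt16<10> (the integer m is shifted straight into the fp16 exponent field, above the 10-bit mantissa) plus an exponent-bias add, and approximating e^r with a short polynomial. A scalar fp32 sketch of the same technique, with illustrative constants rather than the MLAS ones (fp32 shifts by 23, its mantissa width):

#include <cstdint>
#include <cstring>

float ExpSketch(float x) {
    // Round x / ln2 to the nearest integer by adding a large rounding bias:
    // adding 1.5 * 2^23 leaves the integer part in the low mantissa bits.
    const float Log2Reciprocal = 1.44269504f;  // 1 / ln2
    const float RoundingBias = 12582912.0f;    // 1.5 * 2^23
    float biased = x * Log2Reciprocal + RoundingBias;
    float m = biased - RoundingBias;           // nearest integer to x / ln2
    float r = x - m * 0.69314718f;             // residual, |r| <= ln2 / 2

    // 2^m: shift m from the low bits of `biased` into the fp32 exponent
    // field, then add the bit pattern of 1.0f as the exponent bias. This is
    // the scalar analogue of MlasShiftLeftInt16<10> followed by
    // MlasAddInt16(..., MaximumExponent) in the kernel above.
    uint32_t bits;
    std::memcpy(&bits, &biased, sizeof(bits));
    bits = (bits << 23) + 0x3f800000u;
    float two_m;
    std::memcpy(&two_m, &bits, sizeof(two_m));

    // Degree-3 polynomial for e^r -- coarse, but enough to show the idea.
    float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6.0f)));
    return p * two_m;
}

// ExpSketch(1.0f) ~= 2.717 against expf(1.0f) = 2.71828; the kernel's fitted
// higher-degree polynomial and its overflow clamping recover full accuracy.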
@@ -242,44 +242,44 @@ template<typename T>
 MLAS_FORCEINLINE
 T SumExp_Vector_Fp16(T x, T negative_maximum) {
     const auto constants = Get_Exp_Constants<T>();
-    auto clamped_x = MlasMaximum(MlasAdd(x, negative_maximum), constants.LowerRangeSumExp);
+    auto clamped_x = MlasMaximumFloat16(MlasAddFloat16(x, negative_maximum), constants.LowerRangeSumExp);
 
     // integral
-    auto biased = MlasMultiplyAdd(clamped_x, constants.Log2Reciprocal, constants.RoundingBias);
-    auto m = MlasSubtract(biased, constants.RoundingBias);
+    auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias);
+    auto m = MlasSubtractFloat16(biased, constants.RoundingBias);
 
     // residual
-    auto r = MlasMultiplyAdd(m, constants.Log2High, clamped_x);
-    r = MlasMultiplyAdd(m, constants.Log2Mid, r);
-    r = MlasMultiplyAdd(m, constants.Log2Low, r);
+    auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x);
+    r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r);
+    r = MlasMultiplyAddFloat16(m, constants.Log2Low, r);
 
     // 2^m
-    auto normal = MlasShiftLeft<10>(MlasReinterpretAsInt16(biased));
-    normal = MlasAdd(normal, MlasReinterpretAsInt16(constants.MaximumExponent));
+    auto normal = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased));
+    normal = MlasAddInt16(normal, MlasReinterpretFloat16AsInt16(constants.MaximumExponent));
 
     // polynomial approximation
-    auto p = MlasMultiplyAdd(constants.poly_0, r, constants.poly_1);
-    p = MlasMultiplyAdd(p, r, constants.poly_2);
-    p = MlasMultiplyAdd(p, r, constants.poly_3);
-    p = MlasMultiplyAdd(p, r, constants.poly_4);
-    p = MlasMultiplyAdd(p, r, constants.poly_56);
-    p = MlasMultiplyAdd(p, r, constants.poly_56);
+    auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_2);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_3);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_4);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_56);
+    p = MlasMultiplyAddFloat16(p, r, constants.poly_56);
 
-    p = MlasMultiply(p, MlasReinterpretAsFloat16(normal));
+    p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal));
 
     return p;
 }
 
 MLAS_FORCEINLINE
 float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, float16x8_t v4) {
-    v0 = MlasAdd(v0, v1);
-    v2 = MlasAdd(v2, v3);
-    return MlasAdd(MlasAdd(v0, v2), v4);
+    v0 = MlasAddFloat16(v0, v1);
+    v2 = MlasAddFloat16(v2, v3);
+    return MlasAddFloat16(MlasAddFloat16(v0, v2), v4);
 }
 
 MLAS_FORCEINLINE
 float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2) {
-    return MlasAdd(MlasAdd(v0, v1), v2);
+    return MlasAddFloat16(MlasAddFloat16(v0, v1), v2);
 }
 
 MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum) {
@@ -338,7 +338,7 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N
     if (N & 8) {
         auto v0 = MlasLoadFloat16x8(input);
         auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8);
-        accumulator8 = MlasAdd(r0, accumulator8);
+        accumulator8 = MlasAddFloat16(r0, accumulator8);
 
         if (store_output) {
             MlasStoreFloat16x8(output, r0);
@@ -352,7 +352,7 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N
     if (N & 4) {
         auto v0 = MlasLoadFloat16x4(input);
         auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4);
-        accumulator4 = MlasAdd(r0, accumulator4);
+        accumulator4 = MlasAddFloat16(r0, accumulator4);
 
         if (store_output) {
             MlasStoreFloat16x4(output, r0);
@@ -371,8 +371,8 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N
             MlasStorePartialFloat16x4(output, r0, 3);
         }
-        r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 3));
-        accumulator4 = MlasAdd(r0, accumulator4);
+        r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 3));
+        accumulator4 = MlasAddFloat16(r0, accumulator4);
     } else if (N == 2) {
         auto v0 = MlasLoadPartialFloat16x4(input, 2);
         auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4);
@@ -381,9 +381,9 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N
             MlasStorePartialFloat16x4(output, r0, 2);
         }
-        r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 3));
-        r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 2));
-        accumulator4 = MlasAdd(r0, accumulator4);
+        r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 3));
+        r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 2));
+        accumulator4 = MlasAddFloat16(r0, accumulator4);
     } else if (N == 1) {
         auto v0 = MlasLoadPartialFloat16x4(input, 1);
         auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4);
@@ -392,15 +392,15 @@ MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N
             MlasStorePartialFloat16x4(output, r0, 1);
         }
-        r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 3));
-        r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 2));
-        r0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0, MlasReinterpretAsInt16(r0), 1));
-        accumulator4 = MlasAdd(r0, accumulator4);
+        r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 3));
+        r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 2));
+        r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0), MlasReinterpretFloat16AsInt16(r0), 1));
+        accumulator4 = MlasAddFloat16(r0, accumulator4);
     }
 
-    auto t = MlasAdd(vget_low_f16(accumulator8), vget_high_f16(accumulator8));
-    t = MlasAdd(t, accumulator4);
-    _mlas_fp16_ result = MlasReduceAdd(t);
+    auto t = MlasAddFloat16(vget_low_f16(accumulator8), vget_high_f16(accumulator8));
+    t = MlasAddFloat16(t, accumulator4);
+    _mlas_fp16_ result = MlasReduceAddFloat16(t);
 
     return MLAS_FP16::FromBits(result);
 }
@@ -478,20 +478,20 @@ template <typename T>
 MLAS_FORCEINLINE
 T Tanh_Vector_Fp16(T x) {
     const auto constants = Get_Tanh_Constants<T>();
-    x = MlasClamp(x, constants.LowerRange, constants.UpperRange);
+    x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange);
 
-    T x_2 = MlasMultiply(x, x);
+    T x_2 = MlasMultiplyFloat16(x, x);
 
-    T p = MlasMultiplyAdd(constants.alpha_7, x_2, constants.alpha_5);
-    p = MlasMultiplyAdd(p, x_2, constants.alpha_3);
-    p = MlasMultiplyAdd(p, x_2, constants.alpha_1);
-    p = MlasMultiply(p, x);
+    T p = MlasMultiplyAddFloat16(constants.alpha_7, x_2, constants.alpha_5);
+    p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_3);
+    p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_1);
+    p = MlasMultiplyFloat16(p, x);
 
-    T q = MlasMultiplyAdd(constants.beta_6, x_2, constants.beta_4);
-    q = MlasMultiplyAdd(q, x_2, constants.beta_2);
-    q = MlasMultiplyAdd(q, x_2, constants.beta_0);
+    T q = MlasMultiplyAddFloat16(constants.beta_6, x_2, constants.beta_4);
+    q = MlasMultiplyAddFloat16(q, x_2, constants.beta_2);
+    q = MlasMultiplyAddFloat16(q, x_2, constants.beta_0);
 
-    return MlasDivide(p, q);
+    return MlasDivideFloat16(p, q);
 }
 
 void Tanh_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) {
@@ -576,8 +576,8 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     auto softcap4 = MlasBroadcastFloat16x4(Softcap.val);
     auto one8 = MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00);
     auto one4 = MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00);
-    auto softcap_reciprocal8 = MlasDivide(one8, softcap8);
-    auto softcap_reciprocal4 = MlasDivide(one4, softcap4);
+    auto softcap_reciprocal8 = MlasDivideFloat16(one8, softcap8);
+    auto softcap_reciprocal4 = MlasDivideFloat16(one4, softcap4);
 
     while (N >= 32) {
         auto v0 = MlasLoadFloat16x8(input);
@@ -585,20 +585,20 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
         auto v2 = MlasLoadFloat16x8(input + 16);
         auto v3 = MlasLoadFloat16x8(input + 24);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal8);
-        v1 = MlasMultiply(v1, softcap_reciprocal8);
-        v2 = MlasMultiply(v2, softcap_reciprocal8);
-        v3 = MlasMultiply(v3, softcap_reciprocal8);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8);
+        v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8);
+        v2 = MlasMultiplyFloat16(v2, softcap_reciprocal8);
+        v3 = MlasMultiplyFloat16(v3, softcap_reciprocal8);
 
         v0 = Tanh_Vector_Fp16(v0);
         v1 = Tanh_Vector_Fp16(v1);
         v2 = Tanh_Vector_Fp16(v2);
         v3 = Tanh_Vector_Fp16(v3);
 
-        v0 = MlasMultiply(v0, softcap8);
-        v1 = MlasMultiply(v1, softcap8);
-        v2 = MlasMultiply(v2, softcap8);
-        v3 = MlasMultiply(v3, softcap8);
+        v0 = MlasMultiplyFloat16(v0, softcap8);
+        v1 = MlasMultiplyFloat16(v1, softcap8);
+        v2 = MlasMultiplyFloat16(v2, softcap8);
+        v3 = MlasMultiplyFloat16(v3, softcap8);
 
         MlasStoreFloat16x8(output, v0);
         MlasStoreFloat16x8(output + 8, v1);
@@ -614,14 +614,14 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
         auto v0 = MlasLoadFloat16x8(input);
         auto v1 = MlasLoadFloat16x8(input + 8);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal8);
-        v1 = MlasMultiply(v1, softcap_reciprocal8);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8);
+        v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8);
 
         v0 = Tanh_Vector_Fp16(v0);
         v1 = Tanh_Vector_Fp16(v1);
 
-        v0 = MlasMultiply(v0, softcap8);
-        v1 = MlasMultiply(v1, softcap8);
+        v0 = MlasMultiplyFloat16(v0, softcap8);
+        v1 = MlasMultiplyFloat16(v1, softcap8);
 
         MlasStoreFloat16x8(output, v0);
         MlasStoreFloat16x8(output + 8, v1);
@@ -633,9 +633,9 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     if (N & 8) {
         auto v0 = MlasLoadFloat16x8(input);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal8);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8);
         v0 = Tanh_Vector_Fp16(v0);
-        v0 = MlasMultiply(v0, softcap8);
+        v0 = MlasMultiplyFloat16(v0, softcap8);
 
         MlasStoreFloat16x8(output, v0);
         input += 8;
@@ -645,9 +645,9 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
    if (N & 4) {
         auto v0 = MlasLoadFloat16x4(input);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal4);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4);
         v0 = Tanh_Vector_Fp16(v0);
-        v0 = MlasMultiply(v0, softcap4);
+        v0 = MlasMultiplyFloat16(v0, softcap4);
 
         MlasStoreFloat16x4(output, v0);
         input += 4;
@@ -657,21 +657,21 @@ void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     if (N == 3) {
         auto v0 = MlasLoadPartialFloat16x4(input, 3);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal4);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4);
         v0 = Tanh_Vector_Fp16(v0);
-        v0 = MlasMultiply(v0, softcap4);
+        v0 = MlasMultiplyFloat16(v0, softcap4);
 
         MlasStorePartialFloat16x4(output, v0, 3);
     } else if (N == 2) {
         auto v0 = MlasLoadPartialFloat16x4(input, 2);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal4);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4);
         v0 = Tanh_Vector_Fp16(v0);
-        v0 = MlasMultiply(v0, softcap4);
+        v0 = MlasMultiplyFloat16(v0, softcap4);
 
         MlasStorePartialFloat16x4(output, v0, 2);
     } else if (N == 1) {
         auto v0 = MlasLoadPartialFloat16x4(input, 1);
 
-        v0 = MlasMultiply(v0, softcap_reciprocal4);
+        v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4);
         v0 = Tanh_Vector_Fp16(v0);
-        v0 = MlasMultiply(v0, softcap4);
+        v0 = MlasMultiplyFloat16(v0, softcap4);
 
         MlasStorePartialFloat16x4(output, v0, 1);
     }
 }
@@ -687,10 +687,10 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) {
         auto v2 = MlasLoadFloat16x8(input + 16);
         auto v3 = MlasLoadFloat16x8(input + 24);
 
-        v0 = MlasMaximum(v0, v1);
-        v2 = MlasMaximum(v2, v3);
-        v0 = MlasMaximum(v0, v2);
-        max8 = MlasMaximum(max8, v0);
+        v0 = MlasMaximumFloat16(v0, v1);
+        v2 = MlasMaximumFloat16(v2, v3);
+        v0 = MlasMaximumFloat16(v0, v2);
+        max8 = MlasMaximumFloat16(max8, v0);
 
         input += 32;
         N -= 32;
@@ -700,8 +700,8 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) {
         auto v0 = MlasLoadFloat16x8(input);
         auto v1 = MlasLoadFloat16x8(input + 8);
 
-        v0 = MlasMaximum(v0, v1);
-        max8 = MlasMaximum(max8, v0);
+        v0 = MlasMaximumFloat16(v0, v1);
+        max8 = MlasMaximumFloat16(max8, v0);
 
         input += 16;
         N -= 16;
@@ -709,7 +709,7 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) {
     if (N & 8) {
         auto v0 = MlasLoadFloat16x8(input);
 
-        max8 = MlasMaximum(max8, v0);
+        max8 = MlasMaximumFloat16(max8, v0);
 
         input += 8;
         N -= 8;
@@ -717,7 +717,7 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) {
     if (N & 4) {
         auto v0 = MlasLoadFloat16x4(input);
 
-        max4 = MlasMaximum(max4, v0);
+        max4 = MlasMaximumFloat16(max4, v0);
 
         input += 4;
         N -= 4;
@@ -725,24 +725,24 @@ MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) {
     if (N == 3) {
         auto v0 = MlasLoadPartialFloat16x4(input, 3);
-        v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 3));
-        max4 = MlasMaximum(max4, v0);
+        v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3));
+        max4 = MlasMaximumFloat16(max4, v0);
     } else if (N == 2) {
         auto v0 = MlasLoadPartialFloat16x4(input, 2);
-        v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 3));
-        v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 2));
-        max4 = MlasMaximum(max4, v0);
+        v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3));
+        v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2));
+        max4 = MlasMaximumFloat16(max4, v0);
     } else if (N == 1) {
         auto v0 = MlasLoadPartialFloat16x4(input, 1);
-        v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 3));
-        v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 2));
-        v0 = MlasReinterpretAsFloat16(vset_lane_s16((int16_t)0xfbff, MlasReinterpretAsInt16(v0), 1));
-        max4 = MlasMaximum(max4, v0);
+        v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3));
+        v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2));
+        v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast<int16_t>(0xfbff), MlasReinterpretFloat16AsInt16(v0), 1));
+        max4 = MlasMaximumFloat16(max4, v0);
     }
 
-    auto t = MlasMaximum(vget_low_f16(max8), vget_high_f16(max8));
-    t = MlasMaximum(t, max4);
-    _mlas_fp16_ result = MlasReduceMaximum(t);
+    auto t = MlasMaximumFloat16(vget_low_f16(max8), vget_high_f16(max8));
+    t = MlasMaximumFloat16(t, max4);
+    _mlas_fp16_ result = MlasReduceMaximumFloat16(t);
 
     return MLAS_FP16::FromBits(result);
 }
@@ -752,8 +752,8 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     auto* output = reinterpret_cast<_mlas_fp16_*>(Output);
     auto sum8 = MlasBroadcastFloat16x8(Sum.val);
     auto sum4 = MlasBroadcastFloat16x4(Sum.val);
-    auto scale8 = MlasDivide(MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00), sum8);
-    auto scale4 = MlasDivide(MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00), sum4);
+    auto scale8 = MlasDivideFloat16(MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00), sum8);
+    auto scale4 = MlasDivideFloat16(MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00), sum4);
 
     while (N >= 32) {
         auto v0 = MlasLoadFloat16x8(input);
@@ -761,10 +761,10 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
         auto v2 = MlasLoadFloat16x8(input + 16);
         auto v3 = MlasLoadFloat16x8(input + 24);
 
-        v0 = MlasMultiply(v0, scale8);
-        v1 = MlasMultiply(v1, scale8);
-        v2 = MlasMultiply(v2, scale8);
-        v3 = MlasMultiply(v3, scale8);
+        v0 = MlasMultiplyFloat16(v0, scale8);
+        v1 = MlasMultiplyFloat16(v1, scale8);
+        v2 = MlasMultiplyFloat16(v2, scale8);
+        v3 = MlasMultiplyFloat16(v3, scale8);
 
         MlasStoreFloat16x8(output, v0);
         MlasStoreFloat16x8(output + 8, v1);
@@ -780,8 +780,8 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
         auto v0 = MlasLoadFloat16x8(input);
         auto v1 = MlasLoadFloat16x8(input + 8);
 
-        v0 = MlasMultiply(v0, scale8);
-        v1 = MlasMultiply(v1, scale8);
+        v0 = MlasMultiplyFloat16(v0, scale8);
+        v1 = MlasMultiplyFloat16(v1, scale8);
 
         MlasStoreFloat16x8(output, v0);
         MlasStoreFloat16x8(output + 8, v1);
@@ -793,7 +793,7 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     if (N & 8) {
         auto v0 = MlasLoadFloat16x8(input);
 
-        v0 = MlasMultiply(v0, scale8);
+        v0 = MlasMultiplyFloat16(v0, scale8);
 
         MlasStoreFloat16x8(output, v0);
         input += 8;
@@ -803,7 +803,7 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     if (N & 4) {
         auto v0 = MlasLoadFloat16x4(input);
 
-        v0 = MlasMultiply(v0, scale4);
+        v0 = MlasMultiplyFloat16(v0, scale4);
 
         MlasStoreFloat16x4(output, v0);
         input += 4;
@@ -813,15 +813,15 @@ void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, co
     if (N == 3) {
         auto v0 = MlasLoadPartialFloat16x4(input, 3);
-        v0 = MlasMultiply(v0, scale4);
+        v0 = MlasMultiplyFloat16(v0, scale4);
         MlasStorePartialFloat16x4(output, v0, 3);
     } else if (N == 2) {
         auto v0 = MlasLoadPartialFloat16x4(input, 2);
-        v0 = MlasMultiply(v0, scale4);
+        v0 = MlasMultiplyFloat16(v0, scale4);
         MlasStorePartialFloat16x4(output, v0, 2);
     } else if (N == 1) {
         auto v0 = MlasLoadPartialFloat16x4(input, 1);
-        v0 = MlasMultiply(v0, scale4);
+        v0 = MlasMultiplyFloat16(v0, scale4);
         MlasStorePartialFloat16x4(output, v0, 1);
     }
 }
@@ -840,15 +840,15 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N,
         auto v2 = MlasLoadFloat16x8(input + 16);
         auto v3 = MlasLoadFloat16x8(input + 24);
 
-        v0 = MlasAdd(v0, negative_maximum8);
-        v1 = MlasAdd(v1, negative_maximum8);
-        v2 = MlasAdd(v2, negative_maximum8);
-        v3 = MlasAdd(v3, negative_maximum8);
+        v0 = MlasAddFloat16(v0, negative_maximum8);
+        v1 = MlasAddFloat16(v1, negative_maximum8);
+        v2 = MlasAddFloat16(v2, negative_maximum8);
+        v3 = MlasAddFloat16(v3, negative_maximum8);
 
-        v0 = MlasSubtract(v0, log_sum8);
-        v1 = MlasSubtract(v1, log_sum8);
-        v2 = MlasSubtract(v2, log_sum8);
-        v3 = MlasSubtract(v3, log_sum8);
+        v0 = MlasSubtractFloat16(v0, log_sum8);
+        v1 = MlasSubtractFloat16(v1, log_sum8);
+        v2 = MlasSubtractFloat16(v2, log_sum8);
+        v3 = MlasSubtractFloat16(v3, log_sum8);
 
         MlasStoreFloat16x8(output, v0);
         MlasStoreFloat16x8(output + 8, v1);
@@ -864,11 +864,11 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N,
         auto v0 = MlasLoadFloat16x8(input);
         auto v1 = MlasLoadFloat16x8(input + 8);
 
-        v0 = MlasAdd(v0, negative_maximum8);
-        v1 = MlasAdd(v1, negative_maximum8);
+        v0 = MlasAddFloat16(v0, negative_maximum8);
+        v1 = MlasAddFloat16(v1, negative_maximum8);
 
-        v0 = MlasSubtract(v0, log_sum8);
-        v1 = MlasSubtract(v1, log_sum8);
+        v0 = MlasSubtractFloat16(v0, log_sum8);
+        v1 = MlasSubtractFloat16(v1, log_sum8);
 
         MlasStoreFloat16x8(output, v0);
         MlasStoreFloat16x8(output + 8, v1);
@@ -880,8 +880,8 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N,
     if (N & 8) {
         auto v0 = MlasLoadFloat16x8(input);
 
-        v0 = MlasAdd(v0, negative_maximum8);
-        v0 = MlasSubtract(v0, log_sum8);
+        v0 = MlasAddFloat16(v0, negative_maximum8);
+        v0 = MlasSubtractFloat16(v0, log_sum8);
 
         MlasStoreFloat16x8(output, v0);
         input += 8;
@@ -891,8 +891,8 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N,
     if (N & 4) {
         auto v0 = MlasLoadFloat16x4(input);
 
-        v0 = MlasAdd(v0, negative_maximum4);
-        v0 = MlasSubtract(v0, log_sum4);
+        v0 = MlasAddFloat16(v0, negative_maximum4);
+        v0 = MlasSubtractFloat16(v0, log_sum4);
 
         MlasStoreFloat16x4(output, v0);
        input += 4;
@@ -902,18 +902,18 @@ void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N,
     if (N == 3) {
         auto v0 = MlasLoadPartialFloat16x4(input, 3);
-        v0 = MlasAdd(v0, negative_maximum4);
-        v0 = MlasSubtract(v0, log_sum4);
+        v0 = MlasAddFloat16(v0, negative_maximum4);
+        v0 = MlasSubtractFloat16(v0, log_sum4);
         MlasStorePartialFloat16x4(output, v0, 3);
     } else if (N == 2) {
         auto v0 = MlasLoadPartialFloat16x4(input, 2);
-        v0 = MlasAdd(v0, negative_maximum4);
-        v0 = MlasSubtract(v0, log_sum4);
+        v0 = MlasAddFloat16(v0, negative_maximum4);
+        v0 = MlasSubtractFloat16(v0, log_sum4);
         MlasStorePartialFloat16x4(output, v0, 2);
     } else if (N == 1) {
         auto v0 = MlasLoadPartialFloat16x4(input, 1);
-        v0 = MlasAdd(v0, negative_maximum4);
-        v0 = MlasSubtract(v0, log_sum4);
+        v0 = MlasAddFloat16(v0, negative_maximum4);
+        v0 = MlasSubtractFloat16(v0, log_sum4);
         MlasStorePartialFloat16x4(output, v0, 1);
     }
 }

View file

@@ -47,7 +47,7 @@ class MlasComputeExpTest : public MlasTestBase {
     MLAS_FP16* Input = BufferInputFp16.GetBuffer(N);
     MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N);
-    std::default_random_engine generator(N);
+    std::default_random_engine generator(static_cast<unsigned>(N));
     std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
 
     for (size_t n = 0; n < N; n++) {
@@ -73,13 +73,13 @@ class MlasComputeExpTest : public MlasTestBase {
     MLAS_FP16* Input = BufferInputFp16.GetBuffer(N);
     MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N);
-    std::default_random_engine generator(N);
+    std::default_random_engine generator(static_cast<unsigned>(N));
     std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
-    float max_val = std::numeric_limits<float>::min();
+    float max_val = std::numeric_limits<float>::lowest();
     for (size_t n = 0; n < N; n++) {
         Input[n] = MLAS_FP16(distribution(generator));
-        max_val = std::max(max_val, Input[n].ToFloat());
+        max_val = std::fmax(max_val, Input[n].ToFloat());
     }
 
     const auto* dispatch = GetMlasPlatform().SoftmaxDispatch;
@@ -177,7 +177,7 @@ class MlasSoftmaxTest : public MlasTestBase {
     for (size_t nd = 0; nd < N; nd++) {
         Input[nd] = MLAS_FP16(distribution(generator));
-        ref = std::max(ref, Input[nd].ToFloat());
+        ref = std::fmax(ref, Input[nd].ToFloat());
     }
 
     const auto* dispatch = GetMlasPlatform().SoftmaxDispatch;
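A note on the test-side fixes in this last file: std::numeric_limits<float>::min() is the smallest positive normal float (~1.2e-38), not the most negative value, so it is a broken seed for a running maximum whenever every input is negative; lowest() (~-3.4e38) is the correct seed. std::fmax additionally returns the non-NaN operand when one argument is NaN, and static_cast<unsigned>(N) matches the seed type std::default_random_engine expects, presumably silencing MSVC's size_t-narrowing warning. A standalone illustration of the min()/lowest() pitfall:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    // Seeding a running max with min() fails as soon as all inputs are
    // negative; lowest() behaves as intended.
    const float all_negative[] = {-3.0f, -1.5f, -2.25f};
    float bad = std::numeric_limits<float>::min();
    float good = std::numeric_limits<float>::lowest();
    for (float v : all_negative) {
        bad = std::fmax(bad, v);
        good = std::fmax(good, v);
    }
    std::printf("bad seed:  %g\n", bad);   // 1.17549e-38, not the true max
    std::printf("good seed: %g\n", good);  // -1.5
    return 0;
}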