mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
ppc64le: Optimizing the MlasQLinearMulKernel() to use VSX instructions (#12051)
This commit is contained in:
parent
040c2f4517
commit
51f8456c4d
1 changed file with 141 additions and 0 deletions
|
|
@ -284,6 +284,147 @@ MlasQLinearMulKernel(
|
|||
}
|
||||
}
|
||||
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
|
||||
template<typename DataType, bool IsScalarB>
|
||||
static
|
||||
void
|
||||
MlasQLinearMulKernel(
|
||||
const DataType* InputA,
|
||||
float ScaleA,
|
||||
int32_t ZeroPointA,
|
||||
const DataType* InputB,
|
||||
float ScaleB,
|
||||
int32_t ZeroPointB,
|
||||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
DataType* OutputC,
|
||||
size_t N
|
||||
)
|
||||
{
|
||||
const float MinimumValue = (float)((int)std::numeric_limits<DataType>::min() - ZeroPointC);
|
||||
const float MaximumValue = (float)((int)std::numeric_limits<DataType>::max() - ZeroPointC);
|
||||
|
||||
auto ZeroPointAVector = vec_splats(int8_t(ZeroPointA));
|
||||
auto ZeroPointBVector = vec_splats(int8_t(ZeroPointB));
|
||||
auto ZeroPointCVector = vec_splats(float(ZeroPointC));
|
||||
|
||||
auto ScaleAVector = vec_splats(ScaleA);
|
||||
auto ScaleBVector = vec_splats(ScaleB);
|
||||
auto ScaleCVector = vec_splats(ScaleC);
|
||||
|
||||
auto MinimumVector = vec_splats(MinimumValue);
|
||||
auto MaximumVector = vec_splats(MaximumValue);
|
||||
|
||||
float ValueB;
|
||||
__vector float ValueBVector0;
|
||||
__vector float ValueBVector1;
|
||||
__vector float ValueBVector2;
|
||||
__vector float ValueBVector3;
|
||||
|
||||
if (IsScalarB) {
|
||||
ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB);
|
||||
ValueBVector0 = vec_splats(ValueB);
|
||||
ValueBVector1 = vec_splats(ValueB);
|
||||
ValueBVector2 = vec_splats(ValueB);
|
||||
ValueBVector3 = vec_splats(ValueB);
|
||||
}
|
||||
|
||||
while (N >= 16) {
|
||||
auto IntegerVector = vec_sub(vec_xl(0, (int8_t *) InputA), ZeroPointAVector);
|
||||
|
||||
auto ShortVectorL = vec_unpackl(IntegerVector);
|
||||
auto ShortVectorH = vec_unpackh(IntegerVector);
|
||||
auto IntegerVector0 = vec_unpackh(ShortVectorH);
|
||||
auto IntegerVector1 = vec_unpackl(ShortVectorH);
|
||||
auto IntegerVector2 = vec_unpackh(ShortVectorL);
|
||||
auto IntegerVector3 = vec_unpackl(ShortVectorL);
|
||||
|
||||
auto ValueAVector0 = vec_mul(ScaleAVector, vec_ctf(IntegerVector0, 0));
|
||||
auto ValueAVector1 = vec_mul(ScaleAVector, vec_ctf(IntegerVector1, 0));
|
||||
auto ValueAVector2 = vec_mul(ScaleAVector, vec_ctf(IntegerVector2, 0));
|
||||
auto ValueAVector3 = vec_mul(ScaleAVector, vec_ctf(IntegerVector3, 0));
|
||||
|
||||
if (!IsScalarB) {
|
||||
IntegerVector = vec_sub(vec_xl(0, (int8_t *) InputB), ZeroPointBVector);
|
||||
|
||||
auto ShortVectorL = vec_unpackl(IntegerVector);
|
||||
auto ShortVectorH = vec_unpackh(IntegerVector);
|
||||
auto IntegerVector0 = vec_unpackh(ShortVectorH);
|
||||
auto IntegerVector1 = vec_unpackl(ShortVectorH);
|
||||
auto IntegerVector2 = vec_unpackh(ShortVectorL);
|
||||
auto IntegerVector3 = vec_unpackl(ShortVectorL);
|
||||
|
||||
ValueBVector0 = vec_mul(ScaleBVector, vec_ctf(IntegerVector0, 0));
|
||||
ValueBVector1 = vec_mul(ScaleBVector, vec_ctf(IntegerVector1, 0));
|
||||
ValueBVector2 = vec_mul(ScaleBVector, vec_ctf(IntegerVector2, 0));
|
||||
ValueBVector3 = vec_mul(ScaleBVector, vec_ctf(IntegerVector3, 0));
|
||||
}
|
||||
|
||||
auto ValueCVector0 = vec_div(vec_mul(ValueAVector0, ValueBVector0), ScaleCVector);
|
||||
auto ValueCVector1 = vec_div(vec_mul(ValueAVector1, ValueBVector1), ScaleCVector);
|
||||
auto ValueCVector2 = vec_div(vec_mul(ValueAVector2, ValueBVector2), ScaleCVector);
|
||||
auto ValueCVector3 = vec_div(vec_mul(ValueAVector3, ValueBVector3), ScaleCVector);
|
||||
|
||||
ValueCVector0 = vec_min(vec_max(ValueCVector0, MinimumVector), MaximumVector);
|
||||
ValueCVector1 = vec_min(vec_max(ValueCVector1, MinimumVector), MaximumVector);
|
||||
ValueCVector2 = vec_min(vec_max(ValueCVector2, MinimumVector), MaximumVector);
|
||||
ValueCVector3 = vec_min(vec_max(ValueCVector3, MinimumVector), MaximumVector);
|
||||
|
||||
ValueCVector0 = vec_nearbyint(vec_add(ValueCVector0, ZeroPointCVector));
|
||||
ValueCVector1 = vec_nearbyint(vec_add(ValueCVector1, ZeroPointCVector));
|
||||
ValueCVector2 = vec_nearbyint(vec_add(ValueCVector2, ZeroPointCVector));
|
||||
ValueCVector3 = vec_nearbyint(vec_add(ValueCVector3, ZeroPointCVector));
|
||||
|
||||
auto IntegerValueCVector0 = vec_signed(ValueCVector0);
|
||||
auto IntegerValueCVector1 = vec_signed(ValueCVector1);
|
||||
auto IntegerValueCVector2 = vec_signed(ValueCVector2);
|
||||
auto IntegerValueCVector3 = vec_signed(ValueCVector3);
|
||||
|
||||
auto ShortVector0 = vec_pack(IntegerValueCVector0, IntegerValueCVector1);
|
||||
auto ShortVector1 = vec_pack(IntegerValueCVector2, IntegerValueCVector3);
|
||||
auto CharVector = vec_pack(ShortVector0, ShortVector1);
|
||||
|
||||
vec_xst(CharVector, 0, (int8_t *) OutputC);
|
||||
|
||||
OutputC += 16;
|
||||
InputA += 16;
|
||||
InputB += 16;
|
||||
|
||||
N -= 16;
|
||||
|
||||
// Suppress wrong GCC warnings
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueAVector0);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueAVector1);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueAVector2);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueAVector3);
|
||||
}
|
||||
|
||||
while (N > 0) {
|
||||
float ValueA = ScaleA * (int32_t(*InputA) - ZeroPointA);
|
||||
if (!IsScalarB) {
|
||||
ValueB = ScaleB * (int32_t(*InputB) - ZeroPointB);
|
||||
}
|
||||
float ValueC = (ValueA * ValueB) / ScaleC;
|
||||
ValueC = std::min(std::max(ValueC, MinimumValue), MaximumValue);
|
||||
|
||||
*OutputC = (DataType)(int32_t)std::nearbyintf(ValueC + ZeroPointC);
|
||||
|
||||
InputA++;
|
||||
InputB++;
|
||||
OutputC++;
|
||||
N--;
|
||||
}
|
||||
|
||||
// Suppress wrong GCC warnings
|
||||
MLAS_UNREFERENCED_PARAMETER(ScaleAVector);
|
||||
MLAS_UNREFERENCED_PARAMETER(ScaleBVector);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueBVector0);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueBVector1);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueBVector2);
|
||||
MLAS_UNREFERENCED_PARAMETER(ValueBVector3);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Pure C++ implementation.
|
||||
|
|
|
|||
Loading…
Reference in a new issue