mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix avx2 load 32 bytes buffer overrun. (#4455)
* Fix AVX2 32-byte load buffer overrun. * Fix QLinearAdd buffer overrun in the SSE2 code. * Fix QLinearAdd buffer overrun for ARM. * Add an MLAS test for QLinearAdd to cover the overrun and more. * Change the API to save binary space. Add more tests in MLAS to cover different zero points.
This commit is contained in:
parent
d4db83858b
commit
ccbf49e59f
7 changed files with 344 additions and 169 deletions
|
|
@ -118,21 +118,21 @@ Status QLinearAdd<T>::Compute(OpKernelContext* context) const {
|
|||
*context,
|
||||
[](gsl::span<T> output, const T& input0, gsl::span<const T> input1,
|
||||
float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) {
|
||||
MlasQLinearAdd(&input0, A_scale, A_zero_point,
|
||||
input1.data(), B_scale, B_zero_point,
|
||||
C_scale, C_zero_point, output.data(), 1, output.size());
|
||||
MlasQLinearAdd(input1.data(), B_scale, B_zero_point,
|
||||
&input0, A_scale, A_zero_point,
|
||||
C_scale, C_zero_point, output.data(), output.size(), true);
|
||||
},
|
||||
[](gsl::span<T> output, gsl::span<const T> input0, const T& input1,
|
||||
float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) {
|
||||
MlasQLinearAdd(input0.data(), A_scale, A_zero_point,
|
||||
&input1, B_scale, B_zero_point,
|
||||
C_scale, C_zero_point, output.data(), output.size(), 1);
|
||||
C_scale, C_zero_point, output.data(), output.size(), true);
|
||||
},
|
||||
[](gsl::span<T> output, gsl::span<const T> input0, gsl::span<const T> input1,
|
||||
float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) {
|
||||
MlasQLinearAdd(input0.data(), A_scale, A_zero_point,
|
||||
input1.data(), B_scale, B_zero_point,
|
||||
C_scale, C_zero_point, output.data(), output.size(), output.size());
|
||||
C_scale, C_zero_point, output.data(), output.size(), false);
|
||||
},
|
||||
1.0);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -535,7 +535,8 @@ MlasFindMinMaxElement(
|
|||
);
|
||||
|
||||
//
|
||||
// LengthA == LengthB, or (LengthA == 1 or LengthB == 1), broadcasting semantic
|
||||
// InputA is of size N,
|
||||
// InputB is of size 1 if IsScalarB == true, otherwise it is of size N
|
||||
//
|
||||
template<typename DataType>
|
||||
void
|
||||
|
|
@ -550,6 +551,6 @@ MlasQLinearAdd(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
DataType* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ Abstract:
|
|||
--*/
|
||||
|
||||
#include "../../mlasi.h"
|
||||
#include "../../qladd.h"
|
||||
|
||||
template <typename DataType>
|
||||
MLAS_FORCEINLINE
|
||||
|
|
@ -75,7 +76,21 @@ MlasPackS16_256<int8_t>(
|
|||
return _mm256_packs_epi16(a, b);
|
||||
}
|
||||
|
||||
template <typename DataType, bool IsScalarA, bool IsScalarB>
|
||||
//
// Loads 32 bytes into a 256-bit vector without reading past the end of the
// caller's data.
//
// buffer - start of the bytes to load.
// N      - number of bytes that are still readable at buffer.
//
// When at least 32 bytes are readable the vector is loaded straight from
// buffer. Otherwise only the readable bytes are staged through a local
// 32-byte array and the remaining lanes are zero.
//
MLAS_FORCEINLINE
static
__m256i
MlasLoad32Bytes(const uint8_t* buffer, int64_t N)
{
    if (N >= 32) {
        return _mm256_lddqu_si256((const __m256i*)buffer);
    }

    // Zero-fill the staging area so the lanes past the tail hold a defined
    // value instead of whatever happened to be on the stack.
    uint8_t dup[32] = {0};

    // A non-positive count must not be converted to size_t, where it would
    // become an enormous copy length.
    if (N > 0) {
        MlasCopyTailBytes(dup, buffer, (size_t)N);
    }

    return _mm256_lddqu_si256((const __m256i*)dup);
}
|
||||
|
||||
template <typename DataType, bool IsScalarB>
|
||||
static
|
||||
void
|
||||
MlasQLinearAddKernelAvx2Helper(
|
||||
|
|
@ -97,10 +112,6 @@ MlasQLinearAddKernelAvx2Helper(
|
|||
const __m256 VectorScaleRatio_BC = _mm256_set1_ps(ScaleRatio_BC);
|
||||
__m256 VectorFixedPart = _mm256_set1_ps((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB));
|
||||
|
||||
if (IsScalarA) {
|
||||
const auto va_f32x8 = _mm256_set1_ps((float)(int32_t)*InputA);
|
||||
VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(va_f32x8, VectorScaleRatio_AC));
|
||||
}
|
||||
if (IsScalarB) {
|
||||
const auto vb_f32x8 = _mm256_set1_ps((float)(int32_t)*InputB);
|
||||
VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(vb_f32x8, VectorScaleRatio_BC));
|
||||
|
|
@ -110,28 +121,16 @@ MlasQLinearAddKernelAvx2Helper(
|
|||
__m256i vc = _mm256_setzero_si256();
|
||||
while (n > 0) {
|
||||
__m256i va_i8x32, vb_i8x32;
|
||||
if (!IsScalarA) {
|
||||
va_i8x32 = _mm256_lddqu_si256((const __m256i*)InputA);
|
||||
InputA += 32;
|
||||
}
|
||||
va_i8x32 = MlasLoad32Bytes((const uint8_t*)InputA, n);
|
||||
InputA += 32;
|
||||
|
||||
if (!IsScalarB) {
|
||||
vb_i8x32 = _mm256_lddqu_si256((const __m256i*)InputB);
|
||||
vb_i8x32 = MlasLoad32Bytes((const uint8_t*)InputB, n);
|
||||
InputB += 32;
|
||||
}
|
||||
|
||||
__m256 lolo_f32x8, lohi_f32x8, hilo_f32x8, hihi_f32x8;
|
||||
if (IsScalarA) {
|
||||
const auto blo_i16x16 = _mm256_unpacklo_epi8(vb_i8x32, vb_i8x32);
|
||||
const auto bhi_i16x16 = _mm256_unpackhi_epi8(vb_i8x32, vb_i8x32);
|
||||
lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpacklo_epi16(blo_i16x16, blo_i16x16)));
|
||||
lohi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpackhi_epi16(blo_i16x16, blo_i16x16)));
|
||||
hilo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpacklo_epi16(bhi_i16x16, bhi_i16x16)));
|
||||
hihi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpackhi_epi16(bhi_i16x16, bhi_i16x16)));
|
||||
lolo_f32x8 = _mm256_fmadd_ps(lolo_f32x8, VectorScaleRatio_BC, VectorFixedPart);
|
||||
lohi_f32x8 = _mm256_fmadd_ps(lohi_f32x8, VectorScaleRatio_BC, VectorFixedPart);
|
||||
hilo_f32x8 = _mm256_fmadd_ps(hilo_f32x8, VectorScaleRatio_BC, VectorFixedPart);
|
||||
hihi_f32x8 = _mm256_fmadd_ps(hihi_f32x8, VectorScaleRatio_BC, VectorFixedPart);
|
||||
} else if (IsScalarB) {
|
||||
if (IsScalarB) {
|
||||
const auto alo_i16x16 = _mm256_unpacklo_epi8(va_i8x32, va_i8x32);
|
||||
const auto ahi_i16x16 = _mm256_unpackhi_epi8(va_i8x32, va_i8x32);
|
||||
lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpacklo_epi16(alo_i16x16, alo_i16x16)));
|
||||
|
|
@ -210,19 +209,16 @@ MlasQLinearAddS8KernelAvx2(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
int8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
if (LengthA == 1) {
|
||||
MlasQLinearAddKernelAvx2Helper<int8_t, true, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthB);
|
||||
} else if (LengthB == 1) {
|
||||
MlasQLinearAddKernelAvx2Helper<int8_t, false, true>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
|
||||
if (IsScalarB) {
|
||||
MlasQLinearAddKernelAvx2Helper<int8_t, true>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
} else {
|
||||
MlasQLinearAddKernelAvx2Helper<int8_t, false, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
|
||||
MlasQLinearAddKernelAvx2Helper<int8_t, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -238,18 +234,15 @@ MlasQLinearAddU8KernelAvx2(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
uint8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
if (LengthA == 1) {
|
||||
MlasQLinearAddKernelAvx2Helper<uint8_t, true, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthB);
|
||||
} else if (LengthB == 1) {
|
||||
MlasQLinearAddKernelAvx2Helper<uint8_t, false, true>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
|
||||
if (IsScalarB) {
|
||||
MlasQLinearAddKernelAvx2Helper<uint8_t, true>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
} else {
|
||||
MlasQLinearAddKernelAvx2Helper<uint8_t, false, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
|
||||
MlasQLinearAddKernelAvx2Helper<uint8_t, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -485,8 +485,8 @@ void
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
int8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
);
|
||||
|
||||
typedef MLAS_QLINEAR_BINARY_OP_S8_KERNEL* PMLAS_QLINEAR_BINARY_OP_S8_KERNEL;
|
||||
|
|
@ -503,8 +503,8 @@ void
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
uint8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
);
|
||||
|
||||
typedef MLAS_QLINEAR_BINARY_OP_U8_KERNEL* PMLAS_QLINEAR_BINARY_OP_U8_KERNEL;
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ MlasCalcQLinearAddParameters(
|
|||
}
|
||||
|
||||
// Pure C++ helper, back off here in rare case.
|
||||
template<typename DataType, bool IsScalarA, bool IsScalarB>
|
||||
template<typename DataType, bool IsScalarB>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
|
|
@ -69,20 +69,14 @@ MlasQLinearAddKernelRawHelper(
|
|||
const float MinimumValue = (float)((int)std::numeric_limits<DataType>::min() - ZeroPointC);
|
||||
const float MaximumValue = (float)((int)std::numeric_limits<DataType>::max() - ZeroPointC);
|
||||
|
||||
float ValueA;
|
||||
float ValueB;
|
||||
|
||||
if (IsScalarA) {
|
||||
ValueA = ScaleA * (int32_t(InputA[0]) - ZeroPointA);
|
||||
}
|
||||
if (IsScalarB) {
|
||||
ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB);
|
||||
}
|
||||
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
if (!IsScalarA) {
|
||||
ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA);
|
||||
}
|
||||
float ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA);
|
||||
if (!IsScalarB) {
|
||||
ValueB = ScaleB * (int32_t(InputB[n]) - ZeroPointB);
|
||||
}
|
||||
|
|
@ -320,7 +314,7 @@ public:
|
|||
|
||||
#endif
|
||||
|
||||
template<typename DataType, bool IsScalarA, bool IsScalarB>
|
||||
template<typename DataType, bool IsScalarB>
|
||||
static
|
||||
void
|
||||
MlasQLinearAddKernelHelper(
|
||||
|
|
@ -342,7 +336,7 @@ MlasQLinearAddKernelHelper(
|
|||
const float ScaleRatio_AC = ScaleA / ScaleC;
|
||||
const float ScaleRatio_BC = ScaleB / ScaleC;
|
||||
if (!MlasCalcQLinearAddParameters(ScaleRatio_AC, ScaleRatio_BC, Shift, MultiplierA, MultiplierB)) {
|
||||
MlasQLinearAddKernelRawHelper<DataType,IsScalarA, IsScalarB>(
|
||||
MlasQLinearAddKernelRawHelper<DataType, IsScalarB>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
return;
|
||||
}
|
||||
|
|
@ -356,40 +350,17 @@ MlasQLinearAddKernelHelper(
|
|||
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
|
||||
|
||||
int32x4_t vscalar;
|
||||
if (IsScalarA) {
|
||||
const typename SUI::i8x8_t VectorA0 = SUI::vmov_n_i8(*InputA);
|
||||
const int16x8_t va_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorA0, VectorZeroPointA));
|
||||
vscalar = vmulq_s32(vmovl_s16(vget_low_s16(va_s16x8)), VectorMultiplierA);
|
||||
}
|
||||
if (IsScalarB) {
|
||||
const typename SUI::i8x8_t VectorB0 = SUI::vmov_n_i8(*InputB);
|
||||
const int16x8_t vb_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorB0, VectorZeroPointB));
|
||||
vscalar = vmulq_s32(vmovl_s16(vget_low_s16(vb_s16x8)), VectorMultiplierB);
|
||||
}
|
||||
auto n = static_cast<int64_t>(N);
|
||||
|
||||
#if defined(MLAS_NEON64_INTRINSICS)
|
||||
|
||||
while (n >= 32) {
|
||||
while (N >= 32) {
|
||||
int32x4_t vacc0_lo, vacc0_hi, vacc1_lo, vacc1_hi, vacc2_lo, vacc2_hi, vacc3_lo, vacc3_hi;
|
||||
if (IsScalarA) {
|
||||
const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB);
|
||||
const typename SUI::i8x16_t VectorB1 = SUI::vld1q_i8(InputB + 16);
|
||||
InputB += 32;
|
||||
const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB));
|
||||
const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB));
|
||||
const int16x8_t vb2_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB1), VectorZeroPointB));
|
||||
const int16x8_t vb3_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB1), VectorZeroPointB));
|
||||
|
||||
vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB);
|
||||
vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB);
|
||||
vacc2_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb2_s16x8)), VectorMultiplierB);
|
||||
vacc3_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb3_s16x8)), VectorMultiplierB);
|
||||
vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB);
|
||||
vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB);
|
||||
vacc2_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb2_s16x8), VectorMultiplierB);
|
||||
vacc3_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb3_s16x8), VectorMultiplierB);
|
||||
} else if (IsScalarB) {
|
||||
if (IsScalarB) {
|
||||
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA);
|
||||
const typename SUI::i8x16_t VectorA1 = SUI::vld1q_i8(InputA + 16);
|
||||
InputA += 32;
|
||||
|
|
@ -470,26 +441,15 @@ MlasQLinearAddKernelHelper(
|
|||
|
||||
SUI::vst1q_i8(OutputC, vc0);
|
||||
SUI::vst1q_i8(OutputC + 16, vc1);
|
||||
n -= 32;
|
||||
N -= 32;
|
||||
OutputC += 32;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
typename SUI::i8x16_t vc;
|
||||
while (n > 0) {
|
||||
while (N >= 16) {
|
||||
int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi;
|
||||
if (IsScalarA) {
|
||||
const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB);
|
||||
InputB += 16;
|
||||
const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB));
|
||||
const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB));
|
||||
|
||||
vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB);
|
||||
vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB);
|
||||
vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB);
|
||||
vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB);
|
||||
} else if (IsScalarB) {
|
||||
if (IsScalarB) {
|
||||
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA);
|
||||
InputA += 16;
|
||||
const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA));
|
||||
|
|
@ -533,34 +493,86 @@ MlasQLinearAddKernelHelper(
|
|||
// Pack, saturate, and add output zero point.
|
||||
const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC);
|
||||
const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC);
|
||||
vc = SUI::combine_i8_s16(vacc0, vacc1);
|
||||
|
||||
n -= 16;
|
||||
if (n < 0) break;
|
||||
typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1);
|
||||
|
||||
N -= 16;
|
||||
SUI::vst1q_i8(OutputC, vc);
|
||||
OutputC += 16;
|
||||
}
|
||||
|
||||
if (n < 0) {
|
||||
n += 16;
|
||||
if (N > 0) {
|
||||
typename SUI::T TailDataA[16] = { 0 };
|
||||
typename SUI::T TailDataB[16] = { 0 };
|
||||
|
||||
MlasCopyTailBytes((uint8_t*)TailDataA, (const uint8_t*)InputA, N);
|
||||
if (!IsScalarB) {
|
||||
MlasCopyTailBytes((uint8_t*)TailDataB, (const uint8_t*)InputB, N);
|
||||
}
|
||||
|
||||
int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi;
|
||||
if (IsScalarB) {
|
||||
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA);
|
||||
InputA += 16;
|
||||
const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA));
|
||||
const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA));
|
||||
|
||||
vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA);
|
||||
vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA);
|
||||
vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA);
|
||||
vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA);
|
||||
} else {
|
||||
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA);
|
||||
const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(TailDataB);
|
||||
InputA += 16;
|
||||
InputB += 16;
|
||||
const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA));
|
||||
const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB));
|
||||
const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA));
|
||||
const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB));
|
||||
|
||||
vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA);
|
||||
vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA);
|
||||
vacc0_hi = vmulq_s32(MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA);
|
||||
vacc1_hi = vmulq_s32(MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA);
|
||||
|
||||
vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB);
|
||||
vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB);
|
||||
vacc0_hi = vmlaq_s32(vacc0_hi, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB);
|
||||
vacc1_hi = vmlaq_s32(vacc1_hi, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB);
|
||||
}
|
||||
|
||||
vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
|
||||
vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
|
||||
vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
|
||||
vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
|
||||
|
||||
vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
|
||||
vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
|
||||
vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
|
||||
vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
|
||||
|
||||
// Pack, saturate, and add output zero point.
|
||||
const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC);
|
||||
const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC);
|
||||
typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1);
|
||||
|
||||
typename SUI::i8x8_t i8x8 = SUI::vget_low_i8(vc);
|
||||
if (n & 8) {
|
||||
if (N & 8) {
|
||||
SUI::vst1_i8(OutputC, i8x8);
|
||||
OutputC += 8;
|
||||
i8x8 = SUI::vget_high_i8(vc);
|
||||
}
|
||||
if (n & 4) {
|
||||
if (N & 4) {
|
||||
vst1_lane_u32_ex((uint32_t*)OutputC, SUI::vreinterpret_u32_i8(i8x8), 0, 8);
|
||||
OutputC += 4;
|
||||
i8x8 = SUI::template vext_i8<4>(i8x8, i8x8);
|
||||
}
|
||||
if (n & 2) {
|
||||
if (N & 2) {
|
||||
vst1_lane_u16_ex((uint16_t*)OutputC, SUI::vreinterpret_u16_i8(i8x8), 0, 8);
|
||||
OutputC += 2;
|
||||
i8x8 = SUI::template vext_i8<2>(i8x8, i8x8);
|
||||
}
|
||||
if (n & 1) {
|
||||
if (N & 1) {
|
||||
SUI::template vst1_lane_i8<0>(OutputC, i8x8);
|
||||
}
|
||||
}
|
||||
|
|
@ -626,7 +638,7 @@ MlasPackS16_128<int8_t>(
|
|||
return _mm_packs_epi16(a, b);
|
||||
}
|
||||
|
||||
template<typename DataType, bool IsScalarA, bool IsScalarB>
|
||||
template<typename DataType, bool IsScalarB>
|
||||
static
|
||||
void
|
||||
MlasQLinearAddKernelHelper(
|
||||
|
|
@ -649,25 +661,18 @@ MlasQLinearAddKernelHelper(
|
|||
auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB));
|
||||
|
||||
MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi;
|
||||
if (IsScalarA) {
|
||||
va_lo = _mm_set1_ps((float)*InputA);
|
||||
VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC));
|
||||
}
|
||||
if (IsScalarB) {
|
||||
vb_lo = _mm_set1_ps((float)*InputB);
|
||||
VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC));
|
||||
}
|
||||
|
||||
auto n = static_cast<int64_t>(N);
|
||||
MLAS_INT32X4 vc = _mm_setzero_si128();
|
||||
while (n > 0) {
|
||||
if (!IsScalarA) {
|
||||
const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA);
|
||||
const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half);
|
||||
InputA += 8;
|
||||
va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24));
|
||||
va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24));
|
||||
}
|
||||
while (N >= 8) {
|
||||
const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA);
|
||||
const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half);
|
||||
InputA += 8;
|
||||
va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24));
|
||||
va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24));
|
||||
|
||||
if (!IsScalarB) {
|
||||
const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputB);
|
||||
const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half);
|
||||
|
|
@ -677,10 +682,7 @@ MlasQLinearAddKernelHelper(
|
|||
}
|
||||
|
||||
MLAS_INT32X4 r_lo, r_hi;
|
||||
if (IsScalarA) {
|
||||
r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC)));
|
||||
r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_hi, VectorScaleRatio_BC)));
|
||||
} else if (IsScalarB) {
|
||||
if (IsScalarB) {
|
||||
r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)));
|
||||
r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)));
|
||||
} else {
|
||||
|
|
@ -688,26 +690,54 @@ MlasQLinearAddKernelHelper(
|
|||
r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC)));
|
||||
}
|
||||
const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi);
|
||||
vc = MlasPackS16_128<DataType>(vc_i16x8, vc_i16x8);
|
||||
|
||||
n -= 8;
|
||||
if (n < 0) break;
|
||||
MLAS_INT32X4 vc = MlasPackS16_128<DataType>(vc_i16x8, vc_i16x8);
|
||||
|
||||
N -= 8;
|
||||
_mm_storel_epi64((MLAS_INT32X4*)OutputC, vc);
|
||||
OutputC += 8;
|
||||
}
|
||||
|
||||
if (n < 0) {
|
||||
n += 8;
|
||||
if (n & 4) {
|
||||
if (N > 0) {
|
||||
uint8_t TailData[8] = { 0 };
|
||||
|
||||
{
|
||||
MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N);
|
||||
const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData);
|
||||
const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half);
|
||||
InputA += 8;
|
||||
va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24));
|
||||
va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24));
|
||||
}
|
||||
|
||||
if (!IsScalarB) {
|
||||
MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N);
|
||||
const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData);
|
||||
const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half);
|
||||
InputB += 8;
|
||||
vb_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(vb_i16x8, vb_i16x8), 24));
|
||||
vb_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(vb_i16x8, vb_i16x8), 24));
|
||||
}
|
||||
|
||||
MLAS_INT32X4 r_lo, r_hi;
|
||||
if (IsScalarB) {
|
||||
r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)));
|
||||
r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)));
|
||||
} else {
|
||||
r_lo = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)), _mm_mul_ps(vb_lo, VectorScaleRatio_BC)));
|
||||
r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC)));
|
||||
}
|
||||
const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi);
|
||||
MLAS_INT32X4 vc = MlasPackS16_128<DataType>(vc_i16x8, vc_i16x8);
|
||||
|
||||
if (N & 4) {
|
||||
*(int*)OutputC = _mm_cvtsi128_si32(vc);
|
||||
n -= 4;
|
||||
N -= 4;
|
||||
OutputC += 4;
|
||||
vc = _mm_shuffle_epi32(vc, _MM_SHUFFLE(0, 3, 2, 1));
|
||||
}
|
||||
|
||||
uint32_t PackedValueC = (uint32_t)_mm_cvtsi128_si32(vc);
|
||||
for (int64_t i = 0; i < n; ++i) {
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
*((uint8_t*)OutputC + i) = (uint8_t)PackedValueC;
|
||||
PackedValueC >>= 8;
|
||||
}
|
||||
|
|
@ -716,7 +746,7 @@ MlasQLinearAddKernelHelper(
|
|||
|
||||
#else
|
||||
|
||||
template<typename DataType, bool IsScalarA, bool IsScalarB>
|
||||
template<typename DataType, bool IsScalarB>
|
||||
static
|
||||
void
|
||||
MlasQLinearAddKernelHelper(
|
||||
|
|
@ -733,7 +763,7 @@ MlasQLinearAddKernelHelper(
|
|||
)
|
||||
{
|
||||
// Pure C++ implementation.
|
||||
MlasQLinearAddKernelRawHelper<DataType,IsScalarA, IsScalarB>(
|
||||
MlasQLinearAddKernelRawHelper<DataType, IsScalarB>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
}
|
||||
|
||||
|
|
@ -753,22 +783,16 @@ MlasQLinearAddKernel(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
DataType* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
size_t N = std::max(LengthA, LengthB);
|
||||
if (N > 0) {
|
||||
if (LengthA == 1) {
|
||||
MlasQLinearAddKernelHelper<DataType, true, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
} else if (LengthB == 1) {
|
||||
MlasQLinearAddKernelHelper<DataType, false, true>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
} else {
|
||||
MlasQLinearAddKernelHelper<DataType, false, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
}
|
||||
if (IsScalarB) {
|
||||
MlasQLinearAddKernelHelper<DataType, true>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
} else {
|
||||
MlasQLinearAddKernelHelper<DataType, false>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -785,8 +809,8 @@ MlasQLinearAdd<int8_t>(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
int8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
|
|
@ -794,7 +818,7 @@ MlasQLinearAdd<int8_t>(
|
|||
#else
|
||||
MlasQLinearAddKernel<int8_t>(
|
||||
#endif
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
|
||||
}
|
||||
|
||||
template<>
|
||||
|
|
@ -810,8 +834,8 @@ MlasQLinearAdd<uint8_t>(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
uint8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
|
|
@ -819,7 +843,7 @@ MlasQLinearAdd<uint8_t>(
|
|||
#else
|
||||
MlasQLinearAddKernel<uint8_t>(
|
||||
#endif
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
|
||||
}
|
||||
|
||||
//
|
||||
|
|
@ -838,12 +862,12 @@ MlasQLinearAddS8Kernel(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
int8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
MlasQLinearAddKernel<int8_t>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -858,10 +882,10 @@ MlasQLinearAddU8Kernel(
|
|||
float ScaleC,
|
||||
int32_t ZeroPointC,
|
||||
uint8_t* OutputC,
|
||||
size_t LengthA,
|
||||
size_t LengthB
|
||||
size_t N,
|
||||
bool IsScalarB
|
||||
)
|
||||
{
|
||||
MlasQLinearAddKernel<uint8_t>(
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
|
||||
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,26 @@ MlasFp32FromBits(
|
|||
return uf.fp32;
|
||||
}
|
||||
|
||||
//
// Copies the first N bytes of src to target.
//
// target - destination; exactly N bytes are written.
// src    - source; exactly N bytes are read.
// N      - number of bytes to copy.
//
// The copy is done one byte at a time. Neither pointer is known to be
// 4-byte aligned and both refer to byte data, so wider loads and stores
// through a casted pointer (which would also drop the const qualifier of
// src) are deliberately avoided.
//
MLAS_FORCEINLINE
static
void
MlasCopyTailBytes(
    uint8_t* target,
    const uint8_t* src,
    size_t N)
{
    for (size_t i = 0; i < N; i++) {
        target[i] = src[i];
    }
}
|
||||
|
||||
bool
|
||||
MlasCalcQLinearAddParameters(
|
||||
float ScaleRatio_AC,
|
||||
|
|
|
|||
|
|
@ -2495,6 +2495,140 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class MlasQLinearAddTest : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
MatrixGuardBuffer<uint8_t> BufferInputA;
|
||||
MatrixGuardBuffer<uint8_t> BufferInputB;
|
||||
MatrixGuardBuffer<uint8_t> BufferOutput;
|
||||
MatrixGuardBuffer<uint8_t> BufferOutputReference;
|
||||
|
||||
template <typename T>
|
||||
T
|
||||
QLinearAddScalar(
|
||||
T a,
|
||||
float ScaleA,
|
||||
int32_t ZeroPointA,
|
||||
T b,
|
||||
float ScaleB,
|
||||
int32_t ZeroPointB,
|
||||
float ScaleC,
|
||||
int32_t ZeroPointC
|
||||
)
|
||||
{
|
||||
|
||||
constexpr int qmax = std::numeric_limits<T>::max();
|
||||
constexpr int qmin = std::numeric_limits<T>::min();
|
||||
|
||||
float ValueA = ScaleA * (static_cast<int>(a) - ZeroPointA);
|
||||
float ValueB = ScaleB * (static_cast<int>(b) - ZeroPointB);
|
||||
float ValueC = std::nearbyintf((ValueA + ValueB) / ScaleC) + ZeroPointC;
|
||||
int qc = static_cast<int>(ValueC);
|
||||
qc = std::min(qc, qmax);
|
||||
qc = std::max(qc, qmin);
|
||||
return static_cast<T>(qc);
|
||||
}
|
||||
|
||||
template <typename T, bool IsScalarB>
|
||||
void
|
||||
Test(
|
||||
size_t N,
|
||||
float ScaleA,
|
||||
int32_t ZeroPointA,
|
||||
float ScaleB,
|
||||
int32_t ZeroPointB,
|
||||
float ScaleC,
|
||||
int32_t ZeroPointC
|
||||
)
|
||||
{
|
||||
T* InputA = (T*)BufferInputA.GetBuffer(N);
|
||||
T* InputB = (T*)BufferInputB.GetBuffer(IsScalarB ? 1 : N);
|
||||
T* OutputC = (T*)BufferOutput.GetBuffer(N);
|
||||
T* OutputReference = (T*)BufferOutputReference.GetBuffer(N);
|
||||
|
||||
constexpr int MinimumValue = (int)std::numeric_limits<T>::min();
|
||||
constexpr int MaximumValue = (int)std::numeric_limits<T>::max();
|
||||
std::default_random_engine generator(static_cast<unsigned>(N));
|
||||
std::uniform_int_distribution<int> distribution(MinimumValue, MaximumValue);
|
||||
|
||||
if (IsScalarB) {
|
||||
InputB[0] = static_cast<T>(distribution(generator));
|
||||
}
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
InputA[n] = static_cast<T>(distribution(generator));
|
||||
if (!IsScalarB) {
|
||||
InputB[n] = static_cast<T>(distribution(generator));
|
||||
}
|
||||
OutputReference[n] = QLinearAddScalar(InputA[n], ScaleA, ZeroPointA, InputB[IsScalarB ? 0 : n], ScaleB, ZeroPointB, ScaleC, ZeroPointC);
|
||||
}
|
||||
|
||||
MlasQLinearAdd(InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
|
||||
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
int diff = (int)OutputC[n] - (int)OutputReference[n];
|
||||
if (diff < -1 || diff > 1) {
|
||||
printf("Test IsScalarB=%d difference @%u of %u, %d(%f,%d) + %d(%f,%d) => %d(%f,%d) (expecting %d)\n",
|
||||
int(IsScalarB), static_cast<unsigned>(n), static_cast<unsigned>(N),
|
||||
static_cast<int>(InputA[n]), ScaleA, ZeroPointA,
|
||||
static_cast<int>(InputB[IsScalarB ? 0 : n]), ScaleB, ZeroPointB,
|
||||
static_cast<int>(OutputC[n]), ScaleC, ZeroPointC,
|
||||
static_cast<int>(OutputReference[n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
ExecuteShort(
|
||||
void
|
||||
) override
|
||||
{
|
||||
// uint8_t test
|
||||
static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 };
|
||||
for (size_t a = 0; a < _countof(zero_points); a++) {
|
||||
uint8_t offa = zero_points[a];
|
||||
|
||||
for (size_t b = 0; b < _countof(zero_points); b++) {
|
||||
uint8_t offb = zero_points[b];
|
||||
|
||||
for (size_t c = 0; c < _countof(zero_points); c++) {
|
||||
uint8_t offc = zero_points[c];
|
||||
|
||||
for (size_t n = 1; n < 128; n++) {
|
||||
// vector + vector
|
||||
Test<uint8_t, false>(n, 10.f, offa, 10.f, offb, 20.f, offc);
|
||||
|
||||
// vector + scalar
|
||||
Test<uint8_t, true>(n, 10.f, offa, 10.f, offb, 20.f, offc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const int8_t szero_points[] = { -128, -110, -53, 0, 29, 103, 127 };
|
||||
for (size_t a = 0; a < _countof(zero_points); a++) {
|
||||
int8_t offa = szero_points[a];
|
||||
|
||||
for (size_t b = 0; b < _countof(zero_points); b++) {
|
||||
int8_t offb = szero_points[b];
|
||||
|
||||
for (size_t c = 0; c < _countof(zero_points); c++) {
|
||||
int8_t offc = szero_points[c];
|
||||
|
||||
for (size_t n = 1; n < 128; n++) {
|
||||
// vector + vector
|
||||
Test<int8_t, false>(n, 10.f, offa, 10.f, offb, 20.f, offc);
|
||||
|
||||
// vector + scalar
|
||||
Test<int8_t, true>(n, 10.f, offa, 10.f, offb, 20.f, offc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
class MlasFindMinMaxElementsTest : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
|
|
@ -2652,6 +2786,9 @@ main(
|
|||
onnxruntime::make_unique<MlasReorderOutputTest>()->ExecuteShort();
|
||||
}
|
||||
|
||||
printf("QLinearAdd tests.\n");
|
||||
onnxruntime::make_unique<MlasQLinearAddTest>()->ExecuteShort();
|
||||
|
||||
printf("Done.\n");
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
Loading…
Reference in a new issue