Fix avx2 load 32 bytes buffer overrun. (#4455)

* Fix avx2 load 32 bytes buffer overrun.

* Fix qladd buffer overrun for sse2 code.

* Fix QLinearAdd buffer overrun for arm.

* Add mlas test for qladd to cover overrun and more.

* Change API to save binary space. Add more test in mlas to cover different zeropoints.
This commit is contained in:
Zhang Lei 2020-07-09 15:54:31 -07:00 committed by GitHub
parent d4db83858b
commit ccbf49e59f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 344 additions and 169 deletions

View file

@ -118,21 +118,21 @@ Status QLinearAdd<T>::Compute(OpKernelContext* context) const {
*context,
[](gsl::span<T> output, const T& input0, gsl::span<const T> input1,
float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) {
MlasQLinearAdd(&input0, A_scale, A_zero_point,
input1.data(), B_scale, B_zero_point,
C_scale, C_zero_point, output.data(), 1, output.size());
MlasQLinearAdd(input1.data(), B_scale, B_zero_point,
&input0, A_scale, A_zero_point,
C_scale, C_zero_point, output.data(), output.size(), true);
},
[](gsl::span<T> output, gsl::span<const T> input0, const T& input1,
float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) {
MlasQLinearAdd(input0.data(), A_scale, A_zero_point,
&input1, B_scale, B_zero_point,
C_scale, C_zero_point, output.data(), output.size(), 1);
C_scale, C_zero_point, output.data(), output.size(), true);
},
[](gsl::span<T> output, gsl::span<const T> input0, gsl::span<const T> input1,
float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) {
MlasQLinearAdd(input0.data(), A_scale, A_zero_point,
input1.data(), B_scale, B_zero_point,
C_scale, C_zero_point, output.data(), output.size(), output.size());
C_scale, C_zero_point, output.data(), output.size(), false);
},
1.0);
}

View file

@ -535,7 +535,8 @@ MlasFindMinMaxElement(
);
//
// LengthA == LengthB, or (LengthA == 1 or LengthB == 1), broadcasting semantic
// InputA is of size N,
// Input B is of size 1 if IsScalarB == true, otherwise it is of size N
//
template<typename DataType>
void
@ -550,6 +551,6 @@ MlasQLinearAdd(
float ScaleC,
int32_t ZeroPointC,
DataType* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
);

View file

@ -19,6 +19,7 @@ Abstract:
--*/
#include "../../mlasi.h"
#include "../../qladd.h"
template <typename DataType>
MLAS_FORCEINLINE
@ -75,7 +76,21 @@ MlasPackS16_256<int8_t>(
return _mm256_packs_epi16(a, b);
}
template <typename DataType, bool IsScalarA, bool IsScalarB>
MLAS_FORCEINLINE
static
__m256i
MlasLoad32Bytes(const uint8_t* buffer, int64_t N)
{
if (N >= 32) {
return _mm256_lddqu_si256((const __m256i*)buffer);
} else {
uint8_t dup[32];
MlasCopyTailBytes(dup, buffer, (size_t)N);
return _mm256_lddqu_si256((const __m256i*)dup);
}
}
template <typename DataType, bool IsScalarB>
static
void
MlasQLinearAddKernelAvx2Helper(
@ -97,10 +112,6 @@ MlasQLinearAddKernelAvx2Helper(
const __m256 VectorScaleRatio_BC = _mm256_set1_ps(ScaleRatio_BC);
__m256 VectorFixedPart = _mm256_set1_ps((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB));
if (IsScalarA) {
const auto va_f32x8 = _mm256_set1_ps((float)(int32_t)*InputA);
VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(va_f32x8, VectorScaleRatio_AC));
}
if (IsScalarB) {
const auto vb_f32x8 = _mm256_set1_ps((float)(int32_t)*InputB);
VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(vb_f32x8, VectorScaleRatio_BC));
@ -110,28 +121,16 @@ MlasQLinearAddKernelAvx2Helper(
__m256i vc = _mm256_setzero_si256();
while (n > 0) {
__m256i va_i8x32, vb_i8x32;
if (!IsScalarA) {
va_i8x32 = _mm256_lddqu_si256((const __m256i*)InputA);
InputA += 32;
}
va_i8x32 = MlasLoad32Bytes((const uint8_t*)InputA, n);
InputA += 32;
if (!IsScalarB) {
vb_i8x32 = _mm256_lddqu_si256((const __m256i*)InputB);
vb_i8x32 = MlasLoad32Bytes((const uint8_t*)InputB, n);
InputB += 32;
}
__m256 lolo_f32x8, lohi_f32x8, hilo_f32x8, hihi_f32x8;
if (IsScalarA) {
const auto blo_i16x16 = _mm256_unpacklo_epi8(vb_i8x32, vb_i8x32);
const auto bhi_i16x16 = _mm256_unpackhi_epi8(vb_i8x32, vb_i8x32);
lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpacklo_epi16(blo_i16x16, blo_i16x16)));
lohi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpackhi_epi16(blo_i16x16, blo_i16x16)));
hilo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpacklo_epi16(bhi_i16x16, bhi_i16x16)));
hihi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpackhi_epi16(bhi_i16x16, bhi_i16x16)));
lolo_f32x8 = _mm256_fmadd_ps(lolo_f32x8, VectorScaleRatio_BC, VectorFixedPart);
lohi_f32x8 = _mm256_fmadd_ps(lohi_f32x8, VectorScaleRatio_BC, VectorFixedPart);
hilo_f32x8 = _mm256_fmadd_ps(hilo_f32x8, VectorScaleRatio_BC, VectorFixedPart);
hihi_f32x8 = _mm256_fmadd_ps(hihi_f32x8, VectorScaleRatio_BC, VectorFixedPart);
} else if (IsScalarB) {
if (IsScalarB) {
const auto alo_i16x16 = _mm256_unpacklo_epi8(va_i8x32, va_i8x32);
const auto ahi_i16x16 = _mm256_unpackhi_epi8(va_i8x32, va_i8x32);
lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32<DataType>(_mm256_unpacklo_epi16(alo_i16x16, alo_i16x16)));
@ -210,19 +209,16 @@ MlasQLinearAddS8KernelAvx2(
float ScaleC,
int32_t ZeroPointC,
int8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
if (LengthA == 1) {
MlasQLinearAddKernelAvx2Helper<int8_t, true, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthB);
} else if (LengthB == 1) {
MlasQLinearAddKernelAvx2Helper<int8_t, false, true>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
if (IsScalarB) {
MlasQLinearAddKernelAvx2Helper<int8_t, true>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
} else {
MlasQLinearAddKernelAvx2Helper<int8_t, false, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
MlasQLinearAddKernelAvx2Helper<int8_t, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
}
}
@ -238,18 +234,15 @@ MlasQLinearAddU8KernelAvx2(
float ScaleC,
int32_t ZeroPointC,
uint8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
if (LengthA == 1) {
MlasQLinearAddKernelAvx2Helper<uint8_t, true, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthB);
} else if (LengthB == 1) {
MlasQLinearAddKernelAvx2Helper<uint8_t, false, true>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
if (IsScalarB) {
MlasQLinearAddKernelAvx2Helper<uint8_t, true>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
} else {
MlasQLinearAddKernelAvx2Helper<uint8_t, false, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA);
MlasQLinearAddKernelAvx2Helper<uint8_t, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
}
}

View file

@ -485,8 +485,8 @@ void
float ScaleC,
int32_t ZeroPointC,
int8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
);
typedef MLAS_QLINEAR_BINARY_OP_S8_KERNEL* PMLAS_QLINEAR_BINARY_OP_S8_KERNEL;
@ -503,8 +503,8 @@ void
float ScaleC,
int32_t ZeroPointC,
uint8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
);
typedef MLAS_QLINEAR_BINARY_OP_U8_KERNEL* PMLAS_QLINEAR_BINARY_OP_U8_KERNEL;

View file

@ -49,7 +49,7 @@ MlasCalcQLinearAddParameters(
}
// Pure C++ helper, back off here in rare case.
template<typename DataType, bool IsScalarA, bool IsScalarB>
template<typename DataType, bool IsScalarB>
MLAS_FORCEINLINE
static
void
@ -69,20 +69,14 @@ MlasQLinearAddKernelRawHelper(
const float MinimumValue = (float)((int)std::numeric_limits<DataType>::min() - ZeroPointC);
const float MaximumValue = (float)((int)std::numeric_limits<DataType>::max() - ZeroPointC);
float ValueA;
float ValueB;
if (IsScalarA) {
ValueA = ScaleA * (int32_t(InputA[0]) - ZeroPointA);
}
if (IsScalarB) {
ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB);
}
for (size_t n = 0; n < N; n++) {
if (!IsScalarA) {
ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA);
}
float ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA);
if (!IsScalarB) {
ValueB = ScaleB * (int32_t(InputB[n]) - ZeroPointB);
}
@ -320,7 +314,7 @@ public:
#endif
template<typename DataType, bool IsScalarA, bool IsScalarB>
template<typename DataType, bool IsScalarB>
static
void
MlasQLinearAddKernelHelper(
@ -342,7 +336,7 @@ MlasQLinearAddKernelHelper(
const float ScaleRatio_AC = ScaleA / ScaleC;
const float ScaleRatio_BC = ScaleB / ScaleC;
if (!MlasCalcQLinearAddParameters(ScaleRatio_AC, ScaleRatio_BC, Shift, MultiplierA, MultiplierB)) {
MlasQLinearAddKernelRawHelper<DataType,IsScalarA, IsScalarB>(
MlasQLinearAddKernelRawHelper<DataType, IsScalarB>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
return;
}
@ -356,40 +350,17 @@ MlasQLinearAddKernelHelper(
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
int32x4_t vscalar;
if (IsScalarA) {
const typename SUI::i8x8_t VectorA0 = SUI::vmov_n_i8(*InputA);
const int16x8_t va_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorA0, VectorZeroPointA));
vscalar = vmulq_s32(vmovl_s16(vget_low_s16(va_s16x8)), VectorMultiplierA);
}
if (IsScalarB) {
const typename SUI::i8x8_t VectorB0 = SUI::vmov_n_i8(*InputB);
const int16x8_t vb_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorB0, VectorZeroPointB));
vscalar = vmulq_s32(vmovl_s16(vget_low_s16(vb_s16x8)), VectorMultiplierB);
}
auto n = static_cast<int64_t>(N);
#if defined(MLAS_NEON64_INTRINSICS)
while (n >= 32) {
while (N >= 32) {
int32x4_t vacc0_lo, vacc0_hi, vacc1_lo, vacc1_hi, vacc2_lo, vacc2_hi, vacc3_lo, vacc3_hi;
if (IsScalarA) {
const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB);
const typename SUI::i8x16_t VectorB1 = SUI::vld1q_i8(InputB + 16);
InputB += 32;
const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB));
const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB));
const int16x8_t vb2_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB1), VectorZeroPointB));
const int16x8_t vb3_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB1), VectorZeroPointB));
vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB);
vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB);
vacc2_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb2_s16x8)), VectorMultiplierB);
vacc3_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb3_s16x8)), VectorMultiplierB);
vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB);
vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB);
vacc2_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb2_s16x8), VectorMultiplierB);
vacc3_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb3_s16x8), VectorMultiplierB);
} else if (IsScalarB) {
if (IsScalarB) {
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA);
const typename SUI::i8x16_t VectorA1 = SUI::vld1q_i8(InputA + 16);
InputA += 32;
@ -470,26 +441,15 @@ MlasQLinearAddKernelHelper(
SUI::vst1q_i8(OutputC, vc0);
SUI::vst1q_i8(OutputC + 16, vc1);
n -= 32;
N -= 32;
OutputC += 32;
}
#endif
typename SUI::i8x16_t vc;
while (n > 0) {
while (N >= 16) {
int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi;
if (IsScalarA) {
const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB);
InputB += 16;
const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB));
const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB));
vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB);
vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB);
vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB);
vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB);
} else if (IsScalarB) {
if (IsScalarB) {
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA);
InputA += 16;
const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA));
@ -533,34 +493,86 @@ MlasQLinearAddKernelHelper(
// Pack, saturate, and add output zero point.
const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC);
const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC);
vc = SUI::combine_i8_s16(vacc0, vacc1);
n -= 16;
if (n < 0) break;
typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1);
N -= 16;
SUI::vst1q_i8(OutputC, vc);
OutputC += 16;
}
if (n < 0) {
n += 16;
if (N > 0) {
typename SUI::T TailDataA[16] = { 0 };
typename SUI::T TailDataB[16] = { 0 };
MlasCopyTailBytes((uint8_t*)TailDataA, (const uint8_t*)InputA, N);
if (!IsScalarB) {
MlasCopyTailBytes((uint8_t*)TailDataB, (const uint8_t*)InputB, N);
}
int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi;
if (IsScalarB) {
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA);
InputA += 16;
const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA));
const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA));
vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA);
vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA);
vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA);
vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA);
} else {
const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA);
const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(TailDataB);
InputA += 16;
InputB += 16;
const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA));
const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB));
const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA));
const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB));
vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA);
vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA);
vacc0_hi = vmulq_s32(MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA);
vacc1_hi = vmulq_s32(MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA);
vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB);
vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB);
vacc0_hi = vmlaq_s32(vacc0_hi, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB);
vacc1_hi = vmlaq_s32(vacc1_hi, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB);
}
vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
// Pack, saturate, and add output zero point.
const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC);
const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC);
typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1);
typename SUI::i8x8_t i8x8 = SUI::vget_low_i8(vc);
if (n & 8) {
if (N & 8) {
SUI::vst1_i8(OutputC, i8x8);
OutputC += 8;
i8x8 = SUI::vget_high_i8(vc);
}
if (n & 4) {
if (N & 4) {
vst1_lane_u32_ex((uint32_t*)OutputC, SUI::vreinterpret_u32_i8(i8x8), 0, 8);
OutputC += 4;
i8x8 = SUI::template vext_i8<4>(i8x8, i8x8);
}
if (n & 2) {
if (N & 2) {
vst1_lane_u16_ex((uint16_t*)OutputC, SUI::vreinterpret_u16_i8(i8x8), 0, 8);
OutputC += 2;
i8x8 = SUI::template vext_i8<2>(i8x8, i8x8);
}
if (n & 1) {
if (N & 1) {
SUI::template vst1_lane_i8<0>(OutputC, i8x8);
}
}
@ -626,7 +638,7 @@ MlasPackS16_128<int8_t>(
return _mm_packs_epi16(a, b);
}
template<typename DataType, bool IsScalarA, bool IsScalarB>
template<typename DataType, bool IsScalarB>
static
void
MlasQLinearAddKernelHelper(
@ -649,25 +661,18 @@ MlasQLinearAddKernelHelper(
auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB));
MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi;
if (IsScalarA) {
va_lo = _mm_set1_ps((float)*InputA);
VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC));
}
if (IsScalarB) {
vb_lo = _mm_set1_ps((float)*InputB);
VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC));
}
auto n = static_cast<int64_t>(N);
MLAS_INT32X4 vc = _mm_setzero_si128();
while (n > 0) {
if (!IsScalarA) {
const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA);
const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half);
InputA += 8;
va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24));
va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24));
}
while (N >= 8) {
const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA);
const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half);
InputA += 8;
va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24));
va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24));
if (!IsScalarB) {
const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputB);
const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half);
@ -677,10 +682,7 @@ MlasQLinearAddKernelHelper(
}
MLAS_INT32X4 r_lo, r_hi;
if (IsScalarA) {
r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC)));
r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_hi, VectorScaleRatio_BC)));
} else if (IsScalarB) {
if (IsScalarB) {
r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)));
r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)));
} else {
@ -688,26 +690,54 @@ MlasQLinearAddKernelHelper(
r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC)));
}
const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi);
vc = MlasPackS16_128<DataType>(vc_i16x8, vc_i16x8);
n -= 8;
if (n < 0) break;
MLAS_INT32X4 vc = MlasPackS16_128<DataType>(vc_i16x8, vc_i16x8);
N -= 8;
_mm_storel_epi64((MLAS_INT32X4*)OutputC, vc);
OutputC += 8;
}
if (n < 0) {
n += 8;
if (n & 4) {
if (N > 0) {
uint8_t TailData[8] = { 0 };
{
MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N);
const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData);
const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half);
InputA += 8;
va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24));
va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24));
}
if (!IsScalarB) {
MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N);
const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData);
const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half);
InputB += 8;
vb_lo = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpacklo_epi16(vb_i16x8, vb_i16x8), 24));
vb_hi = _mm_cvtepi32_ps(MlasShiftRightInt32<DataType>(_mm_unpackhi_epi16(vb_i16x8, vb_i16x8), 24));
}
MLAS_INT32X4 r_lo, r_hi;
if (IsScalarB) {
r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)));
r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)));
} else {
r_lo = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)), _mm_mul_ps(vb_lo, VectorScaleRatio_BC)));
r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC)));
}
const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi);
MLAS_INT32X4 vc = MlasPackS16_128<DataType>(vc_i16x8, vc_i16x8);
if (N & 4) {
*(int*)OutputC = _mm_cvtsi128_si32(vc);
n -= 4;
N -= 4;
OutputC += 4;
vc = _mm_shuffle_epi32(vc, _MM_SHUFFLE(0, 3, 2, 1));
}
uint32_t PackedValueC = (uint32_t)_mm_cvtsi128_si32(vc);
for (int64_t i = 0; i < n; ++i) {
for (size_t i = 0; i < N; ++i) {
*((uint8_t*)OutputC + i) = (uint8_t)PackedValueC;
PackedValueC >>= 8;
}
@ -716,7 +746,7 @@ MlasQLinearAddKernelHelper(
#else
template<typename DataType, bool IsScalarA, bool IsScalarB>
template<typename DataType, bool IsScalarB>
static
void
MlasQLinearAddKernelHelper(
@ -733,7 +763,7 @@ MlasQLinearAddKernelHelper(
)
{
// Pure C++ implementation.
MlasQLinearAddKernelRawHelper<DataType,IsScalarA, IsScalarB>(
MlasQLinearAddKernelRawHelper<DataType, IsScalarB>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
}
@ -753,22 +783,16 @@ MlasQLinearAddKernel(
float ScaleC,
int32_t ZeroPointC,
DataType* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
size_t N = std::max(LengthA, LengthB);
if (N > 0) {
if (LengthA == 1) {
MlasQLinearAddKernelHelper<DataType, true, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
} else if (LengthB == 1) {
MlasQLinearAddKernelHelper<DataType, false, true>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
} else {
MlasQLinearAddKernelHelper<DataType, false, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
}
if (IsScalarB) {
MlasQLinearAddKernelHelper<DataType, true>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
} else {
MlasQLinearAddKernelHelper<DataType, false>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
}
}
@ -785,8 +809,8 @@ MlasQLinearAdd<int8_t>(
float ScaleC,
int32_t ZeroPointC,
int8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
#if defined(MLAS_TARGET_AMD64)
@ -794,7 +818,7 @@ MlasQLinearAdd<int8_t>(
#else
MlasQLinearAddKernel<int8_t>(
#endif
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
}
template<>
@ -810,8 +834,8 @@ MlasQLinearAdd<uint8_t>(
float ScaleC,
int32_t ZeroPointC,
uint8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
#if defined(MLAS_TARGET_AMD64)
@ -819,7 +843,7 @@ MlasQLinearAdd<uint8_t>(
#else
MlasQLinearAddKernel<uint8_t>(
#endif
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
}
//
@ -838,12 +862,12 @@ MlasQLinearAddS8Kernel(
float ScaleC,
int32_t ZeroPointC,
int8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
MlasQLinearAddKernel<int8_t>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
}
void
@ -858,10 +882,10 @@ MlasQLinearAddU8Kernel(
float ScaleC,
int32_t ZeroPointC,
uint8_t* OutputC,
size_t LengthA,
size_t LengthB
size_t N,
bool IsScalarB
)
{
MlasQLinearAddKernel<uint8_t>(
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB);
InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
}

View file

@ -48,6 +48,26 @@ MlasFp32FromBits(
return uf.fp32;
}
MLAS_FORCEINLINE
static
void
MlasCopyTailBytes(
uint8_t* target,
const uint8_t* src,
size_t N)
{
while (N >= sizeof(uint32_t)) {
*(uint32_t*)(target) = *(uint32_t*)(src);
N -= sizeof(uint32_t);
target += sizeof(uint32_t);
src += sizeof(uint32_t);
}
while (N > 0) {
*target++ = *src++;
--N;
}
}
bool
MlasCalcQLinearAddParameters(
float ScaleRatio_AC,

View file

@ -2495,6 +2495,140 @@ public:
}
};
class MlasQLinearAddTest : public MlasTestBase
{
private:
MatrixGuardBuffer<uint8_t> BufferInputA;
MatrixGuardBuffer<uint8_t> BufferInputB;
MatrixGuardBuffer<uint8_t> BufferOutput;
MatrixGuardBuffer<uint8_t> BufferOutputReference;
template <typename T>
T
QLinearAddScalar(
T a,
float ScaleA,
int32_t ZeroPointA,
T b,
float ScaleB,
int32_t ZeroPointB,
float ScaleC,
int32_t ZeroPointC
)
{
constexpr int qmax = std::numeric_limits<T>::max();
constexpr int qmin = std::numeric_limits<T>::min();
float ValueA = ScaleA * (static_cast<int>(a) - ZeroPointA);
float ValueB = ScaleB * (static_cast<int>(b) - ZeroPointB);
float ValueC = std::nearbyintf((ValueA + ValueB) / ScaleC) + ZeroPointC;
int qc = static_cast<int>(ValueC);
qc = std::min(qc, qmax);
qc = std::max(qc, qmin);
return static_cast<T>(qc);
}
template <typename T, bool IsScalarB>
void
Test(
size_t N,
float ScaleA,
int32_t ZeroPointA,
float ScaleB,
int32_t ZeroPointB,
float ScaleC,
int32_t ZeroPointC
)
{
T* InputA = (T*)BufferInputA.GetBuffer(N);
T* InputB = (T*)BufferInputB.GetBuffer(IsScalarB ? 1 : N);
T* OutputC = (T*)BufferOutput.GetBuffer(N);
T* OutputReference = (T*)BufferOutputReference.GetBuffer(N);
constexpr int MinimumValue = (int)std::numeric_limits<T>::min();
constexpr int MaximumValue = (int)std::numeric_limits<T>::max();
std::default_random_engine generator(static_cast<unsigned>(N));
std::uniform_int_distribution<int> distribution(MinimumValue, MaximumValue);
if (IsScalarB) {
InputB[0] = static_cast<T>(distribution(generator));
}
for (size_t n = 0; n < N; n++) {
InputA[n] = static_cast<T>(distribution(generator));
if (!IsScalarB) {
InputB[n] = static_cast<T>(distribution(generator));
}
OutputReference[n] = QLinearAddScalar(InputA[n], ScaleA, ZeroPointA, InputB[IsScalarB ? 0 : n], ScaleB, ZeroPointB, ScaleC, ZeroPointC);
}
MlasQLinearAdd(InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB);
for (size_t n = 0; n < N; n++) {
int diff = (int)OutputC[n] - (int)OutputReference[n];
if (diff < -1 || diff > 1) {
printf("Test IsScalarB=%d difference @%u of %u, %d(%f,%d) + %d(%f,%d) => %d(%f,%d) (expecting %d)\n",
int(IsScalarB), static_cast<unsigned>(n), static_cast<unsigned>(N),
static_cast<int>(InputA[n]), ScaleA, ZeroPointA,
static_cast<int>(InputB[IsScalarB ? 0 : n]), ScaleB, ZeroPointB,
static_cast<int>(OutputC[n]), ScaleC, ZeroPointC,
static_cast<int>(OutputReference[n]));
}
}
}
public:
void
ExecuteShort(
void
) override
{
// uint8_t test
static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 };
for (size_t a = 0; a < _countof(zero_points); a++) {
uint8_t offa = zero_points[a];
for (size_t b = 0; b < _countof(zero_points); b++) {
uint8_t offb = zero_points[b];
for (size_t c = 0; c < _countof(zero_points); c++) {
uint8_t offc = zero_points[c];
for (size_t n = 1; n < 128; n++) {
// vector + vector
Test<uint8_t, false>(n, 10.f, offa, 10.f, offb, 20.f, offc);
// vector + scalar
Test<uint8_t, true>(n, 10.f, offa, 10.f, offb, 20.f, offc);
}
}
}
}
static const int8_t szero_points[] = { -128, -110, -53, 0, 29, 103, 127 };
for (size_t a = 0; a < _countof(zero_points); a++) {
int8_t offa = szero_points[a];
for (size_t b = 0; b < _countof(zero_points); b++) {
int8_t offb = szero_points[b];
for (size_t c = 0; c < _countof(zero_points); c++) {
int8_t offc = szero_points[c];
for (size_t n = 1; n < 128; n++) {
// vector + vector
Test<int8_t, false>(n, 10.f, offa, 10.f, offb, 20.f, offc);
// vector + scalar
Test<int8_t, true>(n, 10.f, offa, 10.f, offb, 20.f, offc);
}
}
}
}
}
};
class MlasFindMinMaxElementsTest : public MlasTestBase
{
private:
@ -2652,6 +2786,9 @@ main(
onnxruntime::make_unique<MlasReorderOutputTest>()->ExecuteShort();
}
printf("QLinearAdd tests.\n");
onnxruntime::make_unique<MlasQLinearAddTest>()->ExecuteShort();
printf("Done.\n");
return 0;