diff --git a/onnxruntime/contrib_ops/cpu/qlinear_binary_op.cc b/onnxruntime/contrib_ops/cpu/qlinear_binary_op.cc index 7302ec3a5d..c6fe78dac0 100644 --- a/onnxruntime/contrib_ops/cpu/qlinear_binary_op.cc +++ b/onnxruntime/contrib_ops/cpu/qlinear_binary_op.cc @@ -118,21 +118,21 @@ Status QLinearAdd::Compute(OpKernelContext* context) const { *context, [](gsl::span output, const T& input0, gsl::span input1, float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) { - MlasQLinearAdd(&input0, A_scale, A_zero_point, - input1.data(), B_scale, B_zero_point, - C_scale, C_zero_point, output.data(), 1, output.size()); + MlasQLinearAdd(input1.data(), B_scale, B_zero_point, + &input0, A_scale, A_zero_point, + C_scale, C_zero_point, output.data(), output.size(), true); }, [](gsl::span output, gsl::span input0, const T& input1, float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) { MlasQLinearAdd(input0.data(), A_scale, A_zero_point, &input1, B_scale, B_zero_point, - C_scale, C_zero_point, output.data(), output.size(), 1); + C_scale, C_zero_point, output.data(), output.size(), true); }, [](gsl::span output, gsl::span input0, gsl::span input1, float A_scale, float B_scale, float C_scale, T A_zero_point, T B_zero_point, T C_zero_point) { MlasQLinearAdd(input0.data(), A_scale, A_zero_point, input1.data(), B_scale, B_zero_point, - C_scale, C_zero_point, output.data(), output.size(), output.size()); + C_scale, C_zero_point, output.data(), output.size(), false); }, 1.0); } diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 9345667bc4..77a6decc06 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -535,7 +535,8 @@ MlasFindMinMaxElement( ); // -// LengthA == LengthB, or (LengthA == 1 or LengthB == 1), broadcasting semantic +// InputA is of size N, +// Input B is of size 1 if IsScalarB == true, otherwise it is of size N // template void @@ -550,6 +551,6 @@ MlasQLinearAdd( float ScaleC, int32_t ZeroPointC, DataType* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ); diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp index 7546ad8102..266b289046 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp @@ -19,6 +19,7 @@ Abstract: --*/ #include "../../mlasi.h" +#include "../../qladd.h" template MLAS_FORCEINLINE @@ -75,7 +76,21 @@ MlasPackS16_256( return _mm256_packs_epi16(a, b); } -template +MLAS_FORCEINLINE +static +__m256i +MlasLoad32Bytes(const uint8_t* buffer, int64_t N) +{ + if (N >= 32) { + return _mm256_lddqu_si256((const __m256i*)buffer); + } else { + uint8_t dup[32]; + MlasCopyTailBytes(dup, buffer, (size_t)N); + return _mm256_lddqu_si256((const __m256i*)dup); + } +} + +template static void MlasQLinearAddKernelAvx2Helper( @@ -97,10 +112,6 @@ MlasQLinearAddKernelAvx2Helper( const __m256 VectorScaleRatio_BC = _mm256_set1_ps(ScaleRatio_BC); __m256 VectorFixedPart = _mm256_set1_ps((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); - if (IsScalarA) { - const auto va_f32x8 = _mm256_set1_ps((float)(int32_t)*InputA); - VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(va_f32x8, VectorScaleRatio_AC)); - } if (IsScalarB) { const auto vb_f32x8 = _mm256_set1_ps((float)(int32_t)*InputB); VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(vb_f32x8, VectorScaleRatio_BC)); @@ -110,28 +121,16 @@ MlasQLinearAddKernelAvx2Helper( __m256i vc = _mm256_setzero_si256(); while (n > 0) { __m256i va_i8x32, vb_i8x32; - if (!IsScalarA) { - va_i8x32 = _mm256_lddqu_si256((const __m256i*)InputA); - InputA += 32; - } + va_i8x32 = MlasLoad32Bytes((const uint8_t*)InputA, n); + InputA += 32; + if (!IsScalarB) { - vb_i8x32 = _mm256_lddqu_si256((const __m256i*)InputB); + vb_i8x32 = MlasLoad32Bytes((const uint8_t*)InputB, n); InputB += 32; } __m256 lolo_f32x8, lohi_f32x8, hilo_f32x8, hihi_f32x8; - if (IsScalarA) { - const auto blo_i16x16 = _mm256_unpacklo_epi8(vb_i8x32, vb_i8x32); - const auto bhi_i16x16 = _mm256_unpackhi_epi8(vb_i8x32, vb_i8x32); - lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(blo_i16x16, blo_i16x16))); - lohi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(blo_i16x16, blo_i16x16))); - hilo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(bhi_i16x16, bhi_i16x16))); - hihi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(bhi_i16x16, bhi_i16x16))); - lolo_f32x8 = _mm256_fmadd_ps(lolo_f32x8, VectorScaleRatio_BC, VectorFixedPart); - lohi_f32x8 = _mm256_fmadd_ps(lohi_f32x8, VectorScaleRatio_BC, VectorFixedPart); - hilo_f32x8 = _mm256_fmadd_ps(hilo_f32x8, VectorScaleRatio_BC, VectorFixedPart); - hihi_f32x8 = _mm256_fmadd_ps(hihi_f32x8, VectorScaleRatio_BC, VectorFixedPart); - } else if (IsScalarB) { + if (IsScalarB) { const auto alo_i16x16 = _mm256_unpacklo_epi8(va_i8x32, va_i8x32); const auto ahi_i16x16 = _mm256_unpackhi_epi8(va_i8x32, va_i8x32); lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(alo_i16x16, alo_i16x16))); @@ -210,19 +209,16 @@ MlasQLinearAddS8KernelAvx2( float ScaleC, int32_t ZeroPointC, int8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { - if (LengthA == 1) { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthB); - } else if (LengthB == 1) { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA); + if (IsScalarB) { + MlasQLinearAddKernelAvx2Helper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } else { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA); + MlasQLinearAddKernelAvx2Helper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } @@ -238,18 +234,15 @@ MlasQLinearAddU8KernelAvx2( float ScaleC, int32_t ZeroPointC, uint8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { - if (LengthA == 1) { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthB); - } else if (LengthB == 1) { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA); + if (IsScalarB) { + MlasQLinearAddKernelAvx2Helper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } else { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA); + MlasQLinearAddKernelAvx2Helper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index d44f696aea..6da6413d56 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -485,8 +485,8 @@ void float ScaleC, int32_t ZeroPointC, int8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ); typedef MLAS_QLINEAR_BINARY_OP_S8_KERNEL* PMLAS_QLINEAR_BINARY_OP_S8_KERNEL; @@ -503,8 +503,8 @@ void float ScaleC, int32_t ZeroPointC, uint8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ); typedef MLAS_QLINEAR_BINARY_OP_U8_KERNEL* PMLAS_QLINEAR_BINARY_OP_U8_KERNEL; diff --git a/onnxruntime/core/mlas/lib/qladd.cpp b/onnxruntime/core/mlas/lib/qladd.cpp index ff4286230d..8e8818b196 100644 --- a/onnxruntime/core/mlas/lib/qladd.cpp +++ b/onnxruntime/core/mlas/lib/qladd.cpp @@ -49,7 +49,7 @@ MlasCalcQLinearAddParameters( } // Pure C++ helper, back off here in rare case. -template +template MLAS_FORCEINLINE static void @@ -69,20 +69,14 @@ MlasQLinearAddKernelRawHelper( const float MinimumValue = (float)((int)std::numeric_limits::min() - ZeroPointC); const float MaximumValue = (float)((int)std::numeric_limits::max() - ZeroPointC); - float ValueA; float ValueB; - if (IsScalarA) { - ValueA = ScaleA * (int32_t(InputA[0]) - ZeroPointA); - } if (IsScalarB) { ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB); } for (size_t n = 0; n < N; n++) { - if (!IsScalarA) { - ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA); - } + float ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA); if (!IsScalarB) { ValueB = ScaleB * (int32_t(InputB[n]) - ZeroPointB); } @@ -320,7 +314,7 @@ public: #endif -template +template static void MlasQLinearAddKernelHelper( @@ -342,7 +336,7 @@ MlasQLinearAddKernelHelper( const float ScaleRatio_AC = ScaleA / ScaleC; const float ScaleRatio_BC = ScaleB / ScaleC; if (!MlasCalcQLinearAddParameters(ScaleRatio_AC, ScaleRatio_BC, Shift, MultiplierA, MultiplierB)) { - MlasQLinearAddKernelRawHelper( + MlasQLinearAddKernelRawHelper( InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); return; } @@ -356,40 +350,17 @@ MlasQLinearAddKernelHelper( const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); int32x4_t vscalar; - if (IsScalarA) { - const typename SUI::i8x8_t VectorA0 = SUI::vmov_n_i8(*InputA); - const int16x8_t va_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorA0, VectorZeroPointA)); - vscalar = vmulq_s32(vmovl_s16(vget_low_s16(va_s16x8)), VectorMultiplierA); - } if (IsScalarB) { const typename SUI::i8x8_t VectorB0 = SUI::vmov_n_i8(*InputB); const int16x8_t vb_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorB0, VectorZeroPointB)); vscalar = vmulq_s32(vmovl_s16(vget_low_s16(vb_s16x8)), VectorMultiplierB); } - auto n = static_cast(N); #if defined(MLAS_NEON64_INTRINSICS) - while (n >= 32) { + while (N >= 32) { int32x4_t vacc0_lo, vacc0_hi, vacc1_lo, vacc1_hi, vacc2_lo, vacc2_hi, vacc3_lo, vacc3_hi; - if (IsScalarA) { - const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB); - const typename SUI::i8x16_t VectorB1 = SUI::vld1q_i8(InputB + 16); - InputB += 32; - const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB)); - const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB)); - const int16x8_t vb2_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB1), VectorZeroPointB)); - const int16x8_t vb3_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB1), VectorZeroPointB)); - - vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB); - vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB); - vacc2_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb2_s16x8)), VectorMultiplierB); - vacc3_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb3_s16x8)), VectorMultiplierB); - vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB); - vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB); - vacc2_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb2_s16x8), VectorMultiplierB); - vacc3_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb3_s16x8), VectorMultiplierB); - } else if (IsScalarB) { + if (IsScalarB) { const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); const typename SUI::i8x16_t VectorA1 = SUI::vld1q_i8(InputA + 16); InputA += 32; @@ -470,26 +441,15 @@ MlasQLinearAddKernelHelper( SUI::vst1q_i8(OutputC, vc0); SUI::vst1q_i8(OutputC + 16, vc1); - n -= 32; + N -= 32; OutputC += 32; } #endif - typename SUI::i8x16_t vc; - while (n > 0) { + while (N >= 16) { int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi; - if (IsScalarA) { - const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB); - InputB += 16; - const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB)); - const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB)); - - vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB); - vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB); - vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB); - vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB); - } else if (IsScalarB) { + if (IsScalarB) { const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); InputA += 16; const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); @@ -533,34 +493,86 @@ MlasQLinearAddKernelHelper( // Pack, saturate, and add output zero point. const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC); const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC); - vc = SUI::combine_i8_s16(vacc0, vacc1); - - n -= 16; - if (n < 0) break; + typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1); + N -= 16; SUI::vst1q_i8(OutputC, vc); OutputC += 16; } - if (n < 0) { - n += 16; + if (N > 0) { + typename SUI::T TailDataA[16] = { 0 }; + typename SUI::T TailDataB[16] = { 0 }; + + MlasCopyTailBytes((uint8_t*)TailDataA, (const uint8_t*)InputA, N); + if (!IsScalarB) { + MlasCopyTailBytes((uint8_t*)TailDataB, (const uint8_t*)InputB, N); + } + + int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi; + if (IsScalarB) { + const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA); + InputA += 16; + const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); + const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); + + vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); + vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); + vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); + vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); + } else { + const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA); + const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(TailDataB); + InputA += 16; + InputB += 16; + const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); + const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB)); + const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); + const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB)); + + vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); + vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); + vacc0_hi = vmulq_s32(MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); + vacc1_hi = vmulq_s32(MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); + + vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB); + vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB); + vacc0_hi = vmlaq_s32(vacc0_hi, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB); + vacc1_hi = vmlaq_s32(vacc1_hi, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB); + } + + vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); + vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); + vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); + vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); + + vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); + vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); + vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); + vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); + + // Pack, saturate, and add output zero point. + const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC); + const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC); + typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1); + typename SUI::i8x8_t i8x8 = SUI::vget_low_i8(vc); - if (n & 8) { + if (N & 8) { SUI::vst1_i8(OutputC, i8x8); OutputC += 8; i8x8 = SUI::vget_high_i8(vc); } - if (n & 4) { + if (N & 4) { vst1_lane_u32_ex((uint32_t*)OutputC, SUI::vreinterpret_u32_i8(i8x8), 0, 8); OutputC += 4; i8x8 = SUI::template vext_i8<4>(i8x8, i8x8); } - if (n & 2) { + if (N & 2) { vst1_lane_u16_ex((uint16_t*)OutputC, SUI::vreinterpret_u16_i8(i8x8), 0, 8); OutputC += 2; i8x8 = SUI::template vext_i8<2>(i8x8, i8x8); } - if (n & 1) { + if (N & 1) { SUI::template vst1_lane_i8<0>(OutputC, i8x8); } } @@ -626,7 +638,7 @@ MlasPackS16_128( return _mm_packs_epi16(a, b); } -template +template static void MlasQLinearAddKernelHelper( @@ -649,25 +661,18 @@ MlasQLinearAddKernelHelper( auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi; - if (IsScalarA) { - va_lo = _mm_set1_ps((float)*InputA); - VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)); - } if (IsScalarB) { vb_lo = _mm_set1_ps((float)*InputB); VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC)); } - auto n = static_cast(N); - MLAS_INT32X4 vc = _mm_setzero_si128(); - while (n > 0) { - if (!IsScalarA) { - const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA); - const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half); - InputA += 8; - va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24)); - va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24)); - } + while (N >= 8) { + const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA); + const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half); + InputA += 8; + va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24)); + va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24)); + if (!IsScalarB) { const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputB); const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half); @@ -677,10 +682,7 @@ MlasQLinearAddKernelHelper( } MLAS_INT32X4 r_lo, r_hi; - if (IsScalarA) { - r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC))); - r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_hi, VectorScaleRatio_BC))); - } else if (IsScalarB) { + if (IsScalarB) { r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC))); r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC))); } else { @@ -688,26 +690,54 @@ MlasQLinearAddKernelHelper( r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC))); } const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi); - vc = MlasPackS16_128(vc_i16x8, vc_i16x8); - - n -= 8; - if (n < 0) break; + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + N -= 8; _mm_storel_epi64((MLAS_INT32X4*)OutputC, vc); OutputC += 8; } - if (n < 0) { - n += 8; - if (n & 4) { + if (N > 0) { + uint8_t TailData[8] = { 0 }; + + { + MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N); + const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData); + const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half); + InputA += 8; + va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24)); + va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24)); + } + + if (!IsScalarB) { + MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N); + const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData); + const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half); + InputB += 8; + vb_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(vb_i16x8, vb_i16x8), 24)); + vb_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(vb_i16x8, vb_i16x8), 24)); + } + + MLAS_INT32X4 r_lo, r_hi; + if (IsScalarB) { + r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC))); + r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC))); + } else { + r_lo = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)), _mm_mul_ps(vb_lo, VectorScaleRatio_BC))); + r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC))); + } + const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi); + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + if (N & 4) { *(int*)OutputC = _mm_cvtsi128_si32(vc); - n -= 4; + N -= 4; OutputC += 4; vc = _mm_shuffle_epi32(vc, _MM_SHUFFLE(0, 3, 2, 1)); } uint32_t PackedValueC = (uint32_t)_mm_cvtsi128_si32(vc); - for (int64_t i = 0; i < n; ++i) { + for (size_t i = 0; i < N; ++i) { *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; PackedValueC >>= 8; } @@ -716,7 +746,7 @@ MlasQLinearAddKernelHelper( #else -template +template static void MlasQLinearAddKernelHelper( @@ -733,7 +763,7 @@ MlasQLinearAddKernelHelper( ) { // Pure C++ implementation. - MlasQLinearAddKernelRawHelper( + MlasQLinearAddKernelRawHelper( InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } @@ -753,22 +783,16 @@ MlasQLinearAddKernel( float ScaleC, int32_t ZeroPointC, DataType* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { - size_t N = std::max(LengthA, LengthB); - if (N > 0) { - if (LengthA == 1) { - MlasQLinearAddKernelHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } else if (LengthB == 1) { - MlasQLinearAddKernelHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } else { - MlasQLinearAddKernelHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } + if (IsScalarB) { + MlasQLinearAddKernelHelper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); + } else { + MlasQLinearAddKernelHelper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } @@ -785,8 +809,8 @@ MlasQLinearAdd( float ScaleC, int32_t ZeroPointC, int8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { #if defined(MLAS_TARGET_AMD64) @@ -794,7 +818,7 @@ MlasQLinearAdd( #else MlasQLinearAddKernel( #endif - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB); + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); } template<> @@ -810,8 +834,8 @@ MlasQLinearAdd( float ScaleC, int32_t ZeroPointC, uint8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { #if defined(MLAS_TARGET_AMD64) @@ -819,7 +843,7 @@ MlasQLinearAdd( #else MlasQLinearAddKernel( #endif - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB); + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); } // @@ -838,12 +862,12 @@ MlasQLinearAddS8Kernel( float ScaleC, int32_t ZeroPointC, int8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { MlasQLinearAddKernel( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB); + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); } void @@ -858,10 +882,10 @@ MlasQLinearAddU8Kernel( float ScaleC, int32_t ZeroPointC, uint8_t* OutputC, - size_t LengthA, - size_t LengthB + size_t N, + bool IsScalarB ) { MlasQLinearAddKernel( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, LengthA, LengthB); + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); } diff --git a/onnxruntime/core/mlas/lib/qladd.h b/onnxruntime/core/mlas/lib/qladd.h index a1049ce931..5ce00907c0 100644 --- a/onnxruntime/core/mlas/lib/qladd.h +++ b/onnxruntime/core/mlas/lib/qladd.h @@ -48,6 +48,26 @@ MlasFp32FromBits( return uf.fp32; } +MLAS_FORCEINLINE +static +void +MlasCopyTailBytes( + uint8_t* target, + const uint8_t* src, + size_t N) +{ + while (N >= sizeof(uint32_t)) { + *(uint32_t*)(target) = *(uint32_t*)(src); + N -= sizeof(uint32_t); + target += sizeof(uint32_t); + src += sizeof(uint32_t); + } + while (N > 0) { + *target++ = *src++; + --N; + } +} + bool MlasCalcQLinearAddParameters( float ScaleRatio_AC, diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index bb3b0d6189..ce946f7b5b 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -2495,6 +2495,140 @@ public: } }; +class MlasQLinearAddTest : public MlasTestBase +{ +private: + MatrixGuardBuffer BufferInputA; + MatrixGuardBuffer BufferInputB; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + template + T + QLinearAddScalar( + T a, + float ScaleA, + int32_t ZeroPointA, + T b, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC + ) + { + + constexpr int qmax = std::numeric_limits::max(); + constexpr int qmin = std::numeric_limits::min(); + + float ValueA = ScaleA * (static_cast(a) - ZeroPointA); + float ValueB = ScaleB * (static_cast(b) - ZeroPointB); + float ValueC = std::nearbyintf((ValueA + ValueB) / ScaleC) + ZeroPointC; + int qc = static_cast(ValueC); + qc = std::min(qc, qmax); + qc = std::max(qc, qmin); + return static_cast(qc); + } + + template + void + Test( + size_t N, + float ScaleA, + int32_t ZeroPointA, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC + ) + { + T* InputA = (T*)BufferInputA.GetBuffer(N); + T* InputB = (T*)BufferInputB.GetBuffer(IsScalarB ? 1 : N); + T* OutputC = (T*)BufferOutput.GetBuffer(N); + T* OutputReference = (T*)BufferOutputReference.GetBuffer(N); + + constexpr int MinimumValue = (int)std::numeric_limits::min(); + constexpr int MaximumValue = (int)std::numeric_limits::max(); + std::default_random_engine generator(static_cast(N)); + std::uniform_int_distribution distribution(MinimumValue, MaximumValue); + + if (IsScalarB) { + InputB[0] = static_cast(distribution(generator)); + } + for (size_t n = 0; n < N; n++) { + InputA[n] = static_cast(distribution(generator)); + if (!IsScalarB) { + InputB[n] = static_cast(distribution(generator)); + } + OutputReference[n] = QLinearAddScalar(InputA[n], ScaleA, ZeroPointA, InputB[IsScalarB ? 0 : n], ScaleB, ZeroPointB, ScaleC, ZeroPointC); + } + + MlasQLinearAdd(InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); + + for (size_t n = 0; n < N; n++) { + int diff = (int)OutputC[n] - (int)OutputReference[n]; + if (diff < -1 || diff > 1) { + printf("Test IsScalarB=%d difference @%u of %u, %d(%f,%d) + %d(%f,%d) => %d(%f,%d) (expecting %d)\n", + int(IsScalarB), static_cast(n), static_cast(N), + static_cast(InputA[n]), ScaleA, ZeroPointA, + static_cast(InputB[IsScalarB ? 0 : n]), ScaleB, ZeroPointB, + static_cast(OutputC[n]), ScaleC, ZeroPointC, + static_cast(OutputReference[n])); + } + } + } + +public: + void + ExecuteShort( + void + ) override + { + // uint8_t test + static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 }; + for (size_t a = 0; a < _countof(zero_points); a++) { + uint8_t offa = zero_points[a]; + + for (size_t b = 0; b < _countof(zero_points); b++) { + uint8_t offb = zero_points[b]; + + for (size_t c = 0; c < _countof(zero_points); c++) { + uint8_t offc = zero_points[c]; + + for (size_t n = 1; n < 128; n++) { + // vector + vector + Test(n, 10.f, offa, 10.f, offb, 20.f, offc); + + // vector + scalar + Test(n, 10.f, offa, 10.f, offb, 20.f, offc); + } + } + } + } + + static const int8_t szero_points[] = { -128, -110, -53, 0, 29, 103, 127 }; + for (size_t a = 0; a < _countof(zero_points); a++) { + int8_t offa = szero_points[a]; + + for (size_t b = 0; b < _countof(zero_points); b++) { + int8_t offb = szero_points[b]; + + for (size_t c = 0; c < _countof(zero_points); c++) { + int8_t offc = szero_points[c]; + + for (size_t n = 1; n < 128; n++) { + // vector + vector + Test(n, 10.f, offa, 10.f, offb, 20.f, offc); + + // vector + scalar + Test(n, 10.f, offa, 10.f, offb, 20.f, offc); + } + } + } + } + + } +}; + class MlasFindMinMaxElementsTest : public MlasTestBase { private: @@ -2652,6 +2786,9 @@ main( onnxruntime::make_unique()->ExecuteShort(); } + printf("QLinearAdd tests.\n"); + onnxruntime::make_unique()->ExecuteShort(); + printf("Done.\n"); return 0;