diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ecceb64f18..e9f8e44446 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -963,19 +963,83 @@ MlasQuantizeLinear( OutputType ZeroPoint ); +/** + * @brief Requantize a block of the intermediate buffer to the output buffer, + * optionally adding the supplied bias + * + * @param Input Input matrix + * @param InputLeadingDimension Input matrix leading dimension + * @param Output Output matrix + * @param OutputLeadingDimension Output matrix leading dimension + * @param Bias Optional bias vector, to be added + to the input before quantization + * @param Scale Quantization scale + * @param PerColumnScale true if scale is per-column + * @param ZeroPoint quantization zero point value + * @param StartM + * @param StartN + * @param CountM + * @param CountN + * @return +*/ void MLASCALL MlasRequantizeOutput( const int32_t* Input, + size_t InputLeadingDimension, uint8_t* Output, + size_t OutputLeadingDimension, const int32_t* Bias, - size_t M, - size_t N, const float* Scale, bool PerColumnScale, - uint8_t ZeroPoint + uint8_t ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN ); +class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR +{ + public: + MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR( + uint8_t* Output, + size_t OutputLeadingDimension, + const int32_t* Bias, + const float* Scale, + bool PerColumnScale, + uint8_t ZeroPoint) + : Output_(Output), + OutputLeadingDimension_(OutputLeadingDimension), + Bias_(Bias), + Scale_(Scale), + PerColumnScale_(PerColumnScale), + ZeroPoint_(ZeroPoint) + { + } + + void Process(const int32_t* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc) const override + { + MlasRequantizeOutput(C, ldc, Output_, OutputLeadingDimension_, Bias_, Scale_, + PerColumnScale_, ZeroPoint_, StartM, StartN, CountM, CountN); + } + + + private: + uint8_t* Output_; + size_t OutputLeadingDimension_; + const int32_t* Bias_; + const float* Scale_; + bool PerColumnScale_; + uint8_t ZeroPoint_; +}; + + void MLASCALL MlasFindMinMaxElement( diff --git a/onnxruntime/core/mlas/lib/qlgavgpool.cpp b/onnxruntime/core/mlas/lib/qlgavgpool.cpp index 81345dfd97..d8972eecbf 100644 --- a/onnxruntime/core/mlas/lib/qlgavgpool.cpp +++ b/onnxruntime/core/mlas/lib/qlgavgpool.cpp @@ -121,7 +121,9 @@ MlasQLinearGlobalAveragePoolNchw( int32x2_t vacc = vadd_s32(vget_high_s32(vacc_lo), vget_low_s32(vacc_lo)); *sum_buffer++ = vget_lane_s32(vpadd_s32(vacc, vacc), 0); } - MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &scale, false, static_cast(ZeroPointOutput)); + + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, + static_cast(ZeroPointOutput), 0, 0, 1, Channels); } MLAS_FORCEINLINE @@ -256,7 +258,8 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch( vst1q_s32(acc + 4, vacc_hi); } } - MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &Scale, false, Output_zero_point); + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, + Output_zero_point, 0, 0, 1, Channels); } #elif defined(MLAS_SSE2_INTRINSICS) @@ -323,7 +326,8 @@ MlasQLinearGlobalAveragePoolNchw( vsums = _mm_add_epi32(vsums, vshuf); *sum_buffer++ = _mm_cvtsi128_si32(vsums); } - MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &scale, false, static_cast(ZeroPointOutput)); + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, + static_cast(ZeroPointOutput), 0, 0, 1, Channels); } MLAS_FORCEINLINE @@ -515,7 +519,8 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch( _mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi); } } - MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &Scale, false, Output_zero_point); + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, + Output_zero_point, 0, 0, 1, Channels); } #else diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index facb060218..01a5529fb6 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -356,65 +356,46 @@ void MLASCALL MlasRequantizeOutput( const int32_t* Input, + size_t InputLeadingDimension, uint8_t* Output, + size_t OutputLeadingDimension, const int32_t* Bias, - size_t M, - size_t N, const float* Scale, bool PerColumnScale, - uint8_t ZeroPoint + uint8_t ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN ) -/*++ - -Routine Description: - - This routine requantizes the intermediate buffer to the output buffer - optionally adding the supplied bias. - -Arguments: - - Input - Supplies the input matrix. - - Output - Supplies the output matrix. - - Bias - Supplies the optional bias vector to be added to the input buffer - before requantization. - - Buffer - Supplies the output matrix. - - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - Scale - Supplies the quantization scale. - - PerColumnScale - Supplies true if the quantization scale has per-column - values, else false if a single quantization scale applies to the - entire matrix. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ { const __m128 PerMatrixScaleVector = PerColumnScale ? _mm_setzero_ps() : _mm_load1_ps(Scale); const __m128 MinimumValueVector = _mm_set1_ps(float(0 - ZeroPoint)); const __m128 MaximumValueVector = _mm_set1_ps(float(255 - ZeroPoint)); const __m128i ZeroPointVector = _mm_set1_epi32(ZeroPoint); + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + // // Step through each row of the output matrix. // - while (M-- > 0) { + while (CountM-- > 0) { const int32_t* bias = Bias; const float* scale = PerColumnScale ? Scale : nullptr; - size_t n = N; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; // // Process 16 columns of the matrices at a time. @@ -426,11 +407,11 @@ Return Value: // Load the input data and optionally add the per-column bias. // - __m128i IntegerVector0 = _mm_loadu_si128((const __m128i *)&Input[0]); - __m128i IntegerVector1 = _mm_loadu_si128((const __m128i *)&Input[4]); - __m128i IntegerVector2 = _mm_loadu_si128((const __m128i *)&Input[8]); - __m128i IntegerVector3 = _mm_loadu_si128((const __m128i *)&Input[12]); - Input += 16; + __m128i IntegerVector0 = _mm_loadu_si128((const __m128i*)&RowInput[0]); + __m128i IntegerVector1 = _mm_loadu_si128((const __m128i*)&RowInput[4]); + __m128i IntegerVector2 = _mm_loadu_si128((const __m128i*)&RowInput[8]); + __m128i IntegerVector3 = _mm_loadu_si128((const __m128i*)&RowInput[12]); + RowInput += 16; if (bias != nullptr) { IntegerVector0 = _mm_add_epi32(IntegerVector0, _mm_loadu_si128((const __m128i *)&bias[0])); @@ -491,8 +472,8 @@ Return Value: __m128i ByteVector = _mm_packus_epi16(WordVector0, WordVector1); - _mm_storeu_si128((__m128i*)Output, ByteVector); - Output += 16; + _mm_storeu_si128((__m128i*)RowOutput, ByteVector); + RowOutput += 16; n -= 16; } @@ -511,8 +492,8 @@ Return Value: if (n >= 4) { - IntegerVector = _mm_loadu_si128((const __m128i*)&Input[0]); - Input += 4; + IntegerVector = _mm_loadu_si128((const __m128i*)&RowInput[0]); + RowInput += 4; if (bias != nullptr) { IntegerVector = _mm_add_epi32(IntegerVector, _mm_loadu_si128((const __m128i*)&bias[0])); @@ -521,7 +502,7 @@ Return Value: } else { - int32_t IntegerValue = *Input++; + int32_t IntegerValue = *RowInput++; if (bias != nullptr) { IntegerValue += *bias++; @@ -567,19 +548,23 @@ Return Value: if (n >= 4) { - *reinterpret_cast(Output) = OutputValue; - Output += 4; + *reinterpret_cast(RowOutput) = OutputValue; + RowOutput += 4; n -= 4; } else { - *Output = uint8_t(OutputValue); - Output += 1; + *RowOutput = uint8_t(OutputValue); + RowOutput += 1; n -= 1; } } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; } } @@ -589,63 +574,44 @@ void MLASCALL MlasRequantizeOutput( const int32_t* Input, + size_t InputLeadingDimension, uint8_t* Output, + size_t OutputLeadingDimension, const int32_t* Bias, - size_t M, - size_t N, const float* Scale, bool PerColumnScale, - uint8_t ZeroPoint + uint8_t ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN ) -/*++ - -Routine Description: - - This routine requantizes the intermediate buffer to the output buffer - optionally adding the supplied bias. - -Arguments: - - Input - Supplies the input matrix. - - Output - Supplies the output matrix. - - Bias - Supplies the optional bias vector to be added to the input buffer - before requantization. - - Buffer - Supplies the output matrix. - - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - Scale - Supplies the quantization scale. - - PerColumnScale - Supplies true if the quantization scale has per-column - values, else false if a single quantization scale applies to the - entire matrix. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ { const float32x4_t PerMatrixScaleVector = PerColumnScale ? vdupq_n_f32(0) : vld1q_dup_f32(Scale); const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + // // Step through each row of the output matrix. // - while (M-- > 0) { + while (CountM-- > 0) { const int32_t* bias = Bias; const float* scale = PerColumnScale ? Scale : nullptr; - size_t n = N; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; // // Process 16 columns of the matrices at a time. @@ -659,11 +625,11 @@ Return Value: int32x4x4_t IntegerVector; - IntegerVector.val[0] = vld1q_s32(&Input[0]); - IntegerVector.val[1] = vld1q_s32(&Input[4]); - IntegerVector.val[2] = vld1q_s32(&Input[8]); - IntegerVector.val[3] = vld1q_s32(&Input[12]); - Input += 16; + IntegerVector.val[0] = vld1q_s32(&RowInput[0]); + IntegerVector.val[1] = vld1q_s32(&RowInput[4]); + IntegerVector.val[2] = vld1q_s32(&RowInput[8]); + IntegerVector.val[3] = vld1q_s32(&RowInput[12]); + RowInput += 16; if (bias != nullptr) { IntegerVector.val[0] = vaddq_s32(IntegerVector.val[0], vld1q_s32(&bias[0])); @@ -731,8 +697,8 @@ Return Value: WordVector.val[0] = vqaddq_s16(WordVector.val[0], ZeroPointVector); WordVector.val[1] = vqaddq_s16(WordVector.val[1], ZeroPointVector); - vst1q_u8(Output, vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1])); - Output += 16; + vst1q_u8(RowOutput, vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1])); + RowOutput += 16; n -= 16; } @@ -751,8 +717,8 @@ Return Value: if (n >= 4) { - IntegerVector = vld1q_s32(&Input[0]); - Input += 4; + IntegerVector = vld1q_s32(&RowInput[0]); + RowInput += 4; if (bias != nullptr) { IntegerVector = vaddq_s32(IntegerVector, vld1q_s32(&bias[0])); @@ -761,8 +727,8 @@ Return Value: } else { - IntegerVector = vld1q_dup_s32(Input); - Input += 1; + IntegerVector = vld1q_dup_s32(RowInput); + RowInput += 1; if (bias != nullptr) { IntegerVector = vaddq_s32(IntegerVector, vld1q_dup_s32(bias)); @@ -813,19 +779,24 @@ Return Value: if (n >= 4) { - vst1q_lane_u32(reinterpret_cast(Output), vreinterpretq_u32_u8(ByteVector), 0); - Output += 4; + vst1q_lane_u32(reinterpret_cast(RowOutput), + vreinterpretq_u32_u8(ByteVector), 0); + RowOutput += 4; n -= 4; } else { - vst1q_lane_u8(Output, ByteVector, 0); - Output += 1; + vst1q_lane_u8(RowOutput, ByteVector, 0); + RowOutput += 1; n -= 1; } } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; } } @@ -835,68 +806,49 @@ void MLASCALL MlasRequantizeOutput( const int32_t* Input, + size_t InputLeadingDimension, uint8_t* Output, + size_t OutputLeadingDimension, const int32_t* Bias, - size_t M, - size_t N, const float* Scale, bool PerColumnScale, - uint8_t ZeroPoint + uint8_t ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN ) -/*++ - -Routine Description: - - This routine requantizes the intermediate buffer to the output buffer - optionally adding the supplied bias. - -Arguments: - - Input - Supplies the input matrix. - - Output - Supplies the output matrix. - - Bias - Supplies the optional bias vector to be added to the input buffer - before requantization. - - Buffer - Supplies the output matrix. - - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - Scale - Supplies the quantization scale. - - PerColumnScale - Supplies true if the quantization scale has per-column - values, else false if a single quantization scale applies to the - entire matrix. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ { const float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale; const float MinimumValue = float(0 - ZeroPoint); const float MaximumValue = float(255 - ZeroPoint); + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + // // Step through each row of the output matrix. // - while (M-- > 0) { + while (CountM-- > 0) { const int32_t* bias = Bias; const float* scale = Scale; - size_t n = N; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; while (n > 0) { - int32_t IntegerValue = *Input++; + int32_t IntegerValue = *RowInput++; if (bias != nullptr) { IntegerValue += *bias++; @@ -920,10 +872,14 @@ Return Value: IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) - MLAS_ROUNDING_BIAS_MAGIC_BITS; - *Output++ = uint8_t(IntegerValue + ZeroPoint); + *RowOutput++ = uint8_t(IntegerValue + ZeroPoint); n -= 1; } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; } } diff --git a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc index 60068885b6..6e5b780a4d 100644 --- a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc @@ -78,47 +78,52 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const { output_scales[i] = (a_scale_data * b_scale_data[i] / y_scale_data); } - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); - auto gemm_output_data = alloc->Alloc(SafeInt(sizeof(int32_t)) * - static_cast(helper.M()) * static_cast(helper.N())); - BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc)); - auto* gemm_output = static_cast(gemm_output_buffer.get()); - + const size_t num_gemms = helper.OutputOffsets().size(); MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape; gemm_shape.M = static_cast(helper.M()); gemm_shape.N = static_cast(helper.N()); gemm_shape.K = static_cast(helper.K()); gemm_shape.BIsSigned = b_is_signed; - MLAS_GEMM_U8X8_DATA_PARAMS gemm_params; - gemm_params.lda = gemm_shape.K; - gemm_params.ZeroPointA = *a_offset->template Data(); - gemm_params.ldb = gemm_shape.N; - gemm_params.C = gemm_output; - gemm_params.ldc = gemm_shape.N; - gemm_params.BIsPacked = bool(packed_b_); - gemm_params.PerColumnZeroPoints = !IsScalarOr1ElementVector(b_offset); + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); + auto gemm_output_data = alloc->Alloc(SafeInt(gemm_shape.M) * + gemm_shape.N * sizeof(int32_t) * num_gemms); + BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc)); + auto* gemm_output = static_cast(gemm_output_buffer.get()); + + + std::vector gemm_params(num_gemms); + std::vector requant_procs; + requant_procs.reserve(num_gemms); auto b_zp_data = static_cast(b_offset->DataRaw()); - for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - gemm_params.A = a->template Data() + helper.LeftOffsets()[i]; - gemm_params.B = b_data + helper.RightOffsets()[i]; - gemm_params.ZeroPointB = b_zp_data + helper.RightZeroPointOffsets()[i]; + for (size_t i = 0; i < num_gemms; i++) { + gemm_params[i].A = a->template Data() + helper.LeftOffsets()[i]; + gemm_params[i].lda = gemm_shape.K; + gemm_params[i].ZeroPointA = *a_offset->template Data(); - MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool()); + gemm_params[i].B = b_data + helper.RightOffsets()[i]; + gemm_params[i].ldb = gemm_shape.N; + gemm_params[i].BIsPacked = bool(packed_b_); + gemm_params[i].ZeroPointB = b_zp_data + helper.RightZeroPointOffsets()[i]; - //TODO!! consider making this a post processor, so that we can parallize this loop - MlasRequantizeOutput(gemm_output, - y->template MutableData() + helper.OutputOffsets()[i], - nullptr, - static_cast(helper.M()), - static_cast(helper.N()), - output_scales.data() + helper.RightScaleOffsets()[i], - output_scales.size() > 1, - *y_offset->template Data()); + gemm_params[i].C = gemm_output + (gemm_shape.M * gemm_shape.N * i); + gemm_params[i].ldc = gemm_shape.N; + + gemm_params[i].PerColumnZeroPoints = !IsScalarOr1ElementVector(b_offset); + + requant_procs.emplace_back(y->template MutableData() + helper.OutputOffsets()[i], + static_cast(helper.N()), + nullptr, + output_scales.data() + helper.RightScaleOffsets()[i], + output_scales.size() > 1, + *y_offset->template Data()); + gemm_params[i].OutputProcessor = &(requant_procs[i]); } + MlasGemmBatch(gemm_shape, gemm_params.data(), num_gemms, ctx->GetOperatorThreadPool()); + return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc index f5fe4bf4fe..882138d50b 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc @@ -590,13 +590,16 @@ Status QLinearConv::Compute(OpKernelContext* context) const { MlasRequantizeOutput( worker_gemm_output, - worker_requantize_output, - Bdata, - static_cast(output_count), static_cast(M), + worker_requantize_output, + static_cast(M), + Bdata, output_scales.data(), output_scales.size() > 1, - Y_zero_point_value); + Y_zero_point_value, + 0,0, + static_cast(output_count), + static_cast(M)); }; concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, thread_count, conv_worker);