From 2ba637c55836c7b2dca554d96bf993970d086606 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Tue, 10 Nov 2020 23:34:38 -0800 Subject: [PATCH] Implement Scale function for quant gemm (#5632) * Implement a Scale function for quantization Quantized GEMM is always followed by Scaling (PerTensor Or PerColumn), and often need to be accumulated to an existing matrix. This PR implements a post-processor for quantized GEMM result and accumulate it to another matrix. --- cmake/onnxruntime_mlas.cmake | 1 + .../cpu/quantization/attention_quant.cc | 33 +- .../quantization/dynamic_quantize_matmul.cc | 11 +- onnxruntime/core/mlas/inc/mlas.h | 119 ++++-- onnxruntime/core/mlas/lib/qgemm.cpp | 351 ++---------------- onnxruntime/core/mlas/lib/qpostprocessor.cpp | 242 ++++++++++++ onnxruntime/core/util/qmath.cc | 8 +- onnxruntime/test/mlas/unittest.cpp | 105 +++++- 8 files changed, 483 insertions(+), 387 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/qpostprocessor.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e3a5ce28cc..eb6455542f 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -20,6 +20,7 @@ set(mlas_common_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/quantize.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/qladd.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/qlmul.cpp + ${ONNXRUNTIME_ROOT}/core/mlas/lib/qpostprocessor.cpp ) if(MSVC) diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 4dd0c95ad8..97cd1a1a64 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -206,21 +206,26 @@ Status QAttention::Compute(OpKernelContext* context) const { if (packed_weights_) { const auto* packed_weight = static_cast(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size); + + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset, + head_size, + &dequant_scale, + bias_data + weights_offset); MlasGemm( - sequence_length, // M = S - head_size, // N = H - hidden_size, // K = NH - input_data + input_offset, // A - hidden_size, // lda = NH - input_zero_point, // input zero point - packed_weight, // B - weight_zero_point, // weight zero point - weights_is_signed, // weight data type - qkv_dest + qkv_offset, // C - head_size, // ldc - &dequant_scale, // output scale - bias_data + weights_offset, // bias - nullptr); // use single-thread + sequence_length, // M = S + head_size, // N = H + hidden_size, // K = NH + input_data + input_offset, // A + hidden_size, // lda = NH + input_zero_point, // input zero point + packed_weight, // B + weight_zero_point, // weight zero point + weights_is_signed, // weight data type + reinterpret_cast(qkv_dest + qkv_offset), // C + head_size, // ldc + nullptr, // use single-thread + &scale_bias_processor); // output processor + continue; } #endif diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index fabddfb17b..f8675ee4c2 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -52,6 +52,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { #ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 if (packed_b_) { + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(y_data + helper.OutputOffsets()[i], + static_cast(helper.N()), + &multiplier, + bias_data); MlasGemm(static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), @@ -61,11 +65,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, packed_b_.get(), b_zero_point, b_is_signed_, - y_data + helper.OutputOffsets()[i], + reinterpret_cast(y_data + helper.OutputOffsets()[i]), static_cast(helper.N()), - &multiplier, - bias_data, - thread_pool); + thread_pool, + &scale_bias_processor); continue; } #endif diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index eb5724ae39..920f1eeff7 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -207,6 +207,80 @@ MlasGemm( MLAS_THREADPOOL* ThreadPool ); +enum class MLAS_QUANTIZATION_GRANULARITY { + PerMatrix, + PerColumn, +}; + +enum class MLAS_QGEMM_OUTPUT_MODE { + ZeroMode, // overwrite the output buffer + AccumulateMode, // accumulate to the output buffer +}; + +class MLAS_QGEMM_OUTPUT_PROCESSOR { +public: + virtual + void + Process( + const int32_t*, // Supplies the address of matrix to process + size_t, // Supplies the start row index of matrix + size_t, // Supplies the start col index of matrix + size_t, // Supplies the element count per row to process + size_t, // Supplies the element count per col to process + size_t // Supplies the leading dimension of matrix + ) const = 0; +}; + +class MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR { +public: + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR( + float* Output, + size_t LeadingDimensionOutput, + const float* Scale, + const float* Bias, + MLAS_QGEMM_OUTPUT_MODE Mode = MLAS_QGEMM_OUTPUT_MODE::ZeroMode, + MLAS_QUANTIZATION_GRANULARITY QuantGran = MLAS_QUANTIZATION_GRANULARITY::PerMatrix) : + Output_(Output), + LeadingDimensionOutput_(LeadingDimensionOutput), + Scale_(Scale), + Bias_(Bias), + OutputMode_(Mode), + QuantGran_(QuantGran) + { + } + + void + Process( + const int32_t* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const override; + +private: + template + inline + void + ProcessImpl( + const int32_t* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const; + +private: + float* Output_; + size_t LeadingDimensionOutput_; + const float* Scale_; + const float* Bias_; + MLAS_QGEMM_OUTPUT_MODE OutputMode_; + MLAS_QUANTIZATION_GRANULARITY QuantGran_; +}; + void MLASCALL MlasGemm( @@ -222,27 +296,8 @@ MlasGemm( bool BIsSigned, int32_t* C, size_t ldc, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - float* C, - size_t ldc, - const float* Scale, - const float* Bias, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr ); void @@ -259,26 +314,8 @@ MlasGemm( bool BIsSigned, int32_t* C, size_t ldc, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const void* PackedB, - uint8_t offb, - bool BIsSigned, - float* C, - size_t ldc, - const float* Scale, - const float* Bias, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr ); // diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index ad54fbc5c4..ba21a69dd5 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -40,13 +40,11 @@ struct MLAS_GEMM_U8X8_WORK_BLOCK { size_t ldb; int32_t* C; size_t ldc; - const float* Scale; - const float* BiasFloat; uint8_t offa; uint8_t offb; bool BIsPacked; bool BIsSigned; - bool CIsFloat; + const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor; }; // @@ -84,123 +82,6 @@ MlasGemmU8X8ScaleSumBuffer( return MlasGemmU8X8ScaleSumBuffer(SumBuffer, SumBuffer, N, Scale); } -void -MlasGemmU8X8OutputFloat( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock, - int32_t* C, - size_t StartN, - size_t CountM, - size_t CountN - ) -/*++ - -Routine Description: - - This routine converts the output matrix to a floating point format using - the supplied scale and bias parameters. - -Arguments: - - WorkBlock - Supplies the structure containing the GEMM parameters. - - C - Supplies the address of matrix C. - - StartN - Supplies the starting column offset relative to the base of the - work block. This is used to offset into column vectors accessed via the - work block. - - CountM - Supplies the number of rows of the output matrix to process. - - CountN - Supplies the number of columns of the output matrix to process. - -Return Value: - - None. - ---*/ -{ - const size_t ldc = WorkBlock->ldc; - const MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(WorkBlock->Scale); -#if !defined(MLAS_SSE2_INTRINSICS) - const float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector); -#endif - - // - // Check if the optional bias vector was supplied. - // - - const float* BiasFloat = WorkBlock->BiasFloat; - - if (BiasFloat != nullptr) { - - BiasFloat += WorkBlock->RangeStartN + StartN; - - while (CountM-- > 0) { - - const float* bias = BiasFloat; - int32_t* c = C; - size_t n = CountN; - - while (n >= 4) { - - MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c)); - FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector); - FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias)); - MlasStoreFloat32x4(reinterpret_cast(c), FloatVector); - - bias += 4; - c += 4; - n -= 4; - } - - for (size_t offset = 0; offset < n; offset++) { - -#if defined(MLAS_SSE2_INTRINSICS) - __m128 FloatVector = _mm_set_ss(float(c[offset])); - FloatVector = _mm_mul_ss(FloatVector, ScaleVector); - FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset])); - _mm_store_ss(reinterpret_cast(&c[offset]), FloatVector); -#else - *reinterpret_cast(&c[offset]) = float(c[offset]) * ScaleValue + bias[offset]; -#endif - } - - C += ldc; - } - - } else { - - while (CountM-- > 0) { - - int32_t* c = C; - size_t n = CountN; - - while (n >= 4) { - - MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c)); - FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector); - MlasStoreFloat32x4(reinterpret_cast(c), FloatVector); - - c += 4; - n -= 4; - } - - for (size_t offset = 0; offset < n; offset++) { - -#if defined(MLAS_SSE2_INTRINSICS) - __m128 FloatVector = _mm_set_ss((float)c[offset]); - FloatVector = _mm_mul_ss(FloatVector, ScaleVector); - _mm_store_ss(reinterpret_cast(&c[offset]), FloatVector); -#else - *reinterpret_cast(&c[offset]) = float(c[offset]) * ScaleValue; -#endif - } - - C += ldc; - } - } -} - template void MLASCALL @@ -251,7 +132,7 @@ Return Value: // Try to use a GEMV kernel if supported by this kernel type. // - if ((M == 1) && (offa == 0) && (offb == 0) && !WorkBlock->CIsFloat) { + if ((M == 1) && (offa == 0) && (offb == 0) && WorkBlock->OutputProcessor == nullptr) { if (KernelType::TryGemvKernel(A, B, ldb, C, K, N, WorkBlock->BIsSigned)) { return; } @@ -346,8 +227,13 @@ Return Value: RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer, DepthValue, ZeroMode); - if (PostProcess && WorkBlock->CIsFloat) { - MlasGemmU8X8OutputFloat(WorkBlock, c, n, RowsHandled, CountN); + if (PostProcess && WorkBlock->OutputProcessor != nullptr) { + WorkBlock->OutputProcessor->Process(WorkBlock->C, + WorkBlock->RangeStartM + m + CountM - RowsRemaining, + WorkBlock->RangeStartN + n, + RowsHandled, + CountN, + WorkBlock->ldc); } c += ldc * RowsHandled; @@ -508,8 +394,14 @@ Return Value: RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer, DepthValue, ZeroMode); - if (PostProcess && WorkBlock->CIsFloat) { - MlasGemmU8X8OutputFloat(WorkBlock, c, n, RowsHandled, CountN); + if (PostProcess && WorkBlock->OutputProcessor != nullptr) { + WorkBlock->OutputProcessor->Process( + WorkBlock->C, + WorkBlock->RangeStartM + m + CountM - RowsRemaining, + WorkBlock->RangeStartN + n, + RowsHandled, + CountN, + WorkBlock->ldc); } c += ldc * RowsHandled; @@ -2187,7 +2079,8 @@ MlasGemm( bool BIsSigned, int32_t* C, size_t ldc, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor ) /*++ @@ -2227,6 +2120,8 @@ Arguments: ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. + OutputProcessor - Post Processor on C. + Return Value: None. @@ -2250,6 +2145,7 @@ Return Value: WorkBlock.ldb = ldb; WorkBlock.C = C; WorkBlock.ldc = ldc; + WorkBlock.OutputProcessor = OutputProcessor; WorkBlock.offa = offa; WorkBlock.offb = offb; WorkBlock.BIsSigned = BIsSigned; @@ -2261,107 +2157,6 @@ Return Value: MlasGemmU8X8Schedule(&WorkBlock, ThreadPool); } -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - float* C, - size_t ldc, - const float* Scale, - const float* Bias, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the quantized integer matrix/matrix multiply - operation (QGEMM). - -Arguments: - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - offa - Supplies the zero point offset of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - offb - Supplies the zero point offset of matrix B. - - BIsSigned - Supplies true if matrix B is signed data, else false if matrix - B is unsigned data. - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - Scale - Supplies the scale multiplier to apply to each element of matrix C. - Used to scale the integer output of the QGEMM back to a floating point - number. - - Bias - Supplies the bias vector to apply to element of matrix C. The vector - is of length N. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK)); - - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = B; - WorkBlock.ldb = ldb; - WorkBlock.C = (int32_t*)C; - WorkBlock.ldc = ldc; - WorkBlock.Scale = Scale; - WorkBlock.BiasFloat = Bias; - WorkBlock.offa = offa; - WorkBlock.offb = offb; - WorkBlock.BIsSigned = BIsSigned; - WorkBlock.CIsFloat = true; - - // - // Schedule the operation across a set of worker threads. - // - - MlasGemmU8X8Schedule(&WorkBlock, ThreadPool); -} - #endif // MLAS_SUPPORTS_GEMM_U8X8 #ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 @@ -2380,7 +2175,8 @@ MlasGemm( bool BIsSigned, int32_t* C, size_t ldc, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor ) /*++ @@ -2418,100 +2214,7 @@ Arguments: ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. -Return Value: - - None. - ---*/ -{ - MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK)); - - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = PackedB; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.offa = offa; - WorkBlock.offb = offb; - WorkBlock.BIsPacked = true; - WorkBlock.BIsSigned = BIsSigned; - - // - // Schedule the operation across a set of worker threads. - // - - MlasGemmU8X8Schedule(&WorkBlock, ThreadPool); -} - -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const void* PackedB, - uint8_t offb, - bool BIsSigned, - float* C, - size_t ldc, - const float* Scale, - const float* Bias, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the quantized integer matrix/matrix multiply - operation (QGEMM). - -Arguments: - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - offa - Supplies the zero point offset of matrix A. - - PackedB - Supplies the address of packed matrix B. - - offb - Supplies the zero point offset of matrix B. - - BIsSigned - Supplies true if matrix B is signed data, else false if matrix - B is unsigned data. - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - Scale - Supplies the scale multiplier to apply to each element of matrix C. - Used to scale the integer output of the QGEMM back to a floating point - number. - - Bias - Supplies the bias vector to apply to element of matrix C. The vector - is of length N. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. + OutputProcessor - Post Processor on C Return Value: @@ -2535,13 +2238,11 @@ Return Value: WorkBlock.B = PackedB; WorkBlock.C = (int32_t*)C; WorkBlock.ldc = ldc; - WorkBlock.Scale = Scale; - WorkBlock.BiasFloat = Bias; + WorkBlock.OutputProcessor = OutputProcessor, WorkBlock.offa = offa; WorkBlock.offb = offb; WorkBlock.BIsPacked = true; WorkBlock.BIsSigned = BIsSigned; - WorkBlock.CIsFloat = true; // // Schedule the operation across a set of worker threads. diff --git a/onnxruntime/core/mlas/lib/qpostprocessor.cpp b/onnxruntime/core/mlas/lib/qpostprocessor.cpp new file mode 100644 index 0000000000..259ce240d0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qpostprocessor.cpp @@ -0,0 +1,242 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qpostprocessor.cpp + +Abstract: + + This module implements the post processor for QGEMM. + +--*/ + +#include "mlasi.h" + +#ifdef MLAS_SUPPORTS_GEMM_U8X8 + +void MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::Process( + const int32_t* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const +{ + if (Bias_) { + if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { + if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } else { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } + } else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } else { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } + } else { + if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { + if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } else { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } + } else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } else { + ProcessImpl( + C, + StartM, + StartN, + CountM, + CountN, + ldc); + } + } +} + +template +inline +void +MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::ProcessImpl( + const int32_t* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc) const +/*++ + +Routine Description: + + This routine converts the output matrix C to a floating point format using + the stored scale and bias parameters. + +Arguments: + + C - Supplies the address of matrix C. + + StartM - Supplies the starting row offset relative to the matrix. + + StartN - Supplies the starting column offset relative to the matrix. + + CountM - Supplies the number of rows of the output matrix to process. + + CountN - Supplies the number of columns of the output matrix to process. + + ldc - Supplies the leading dimension of C. + +Return Value: + + None. + +--*/ +{ + float* Output = Output_; + const float* Bias = Bias_; + const float* Scale = Scale_; + + if (HasBias) { + Bias += StartN; + } + + if(QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn){ + Scale += StartN; + } + + MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale_); +#if !defined(MLAS_SSE2_INTRINSICS) + float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector); +#endif + + C += StartM * ldc + StartN; + Output += StartM * LeadingDimensionOutput_ + StartN; + + + while (CountM-- > 0) { + + float* c_out = Output; + const int32_t* c = C; + const float* bias = Bias; + const float* scale = Scale; + + size_t n = CountN; + + while (n >= 4) { + + MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c)); + + if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { + ScaleVector = MlasLoadFloat32x4(scale); + scale += 4; + } + + if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + FloatVector = MlasMultiplyAddFloat32x4(FloatVector, ScaleVector, MlasLoadFloat32x4(c_out)); + } else { + FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector); + } + + if (HasBias) { + FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias)); + bias += 4; + } + + MlasStoreFloat32x4(c_out, FloatVector); + + c_out += 4; + c += 4; + n -= 4; + } + + for (size_t offset = 0; offset < n; offset++) { + +#if defined(MLAS_SSE2_INTRINSICS) + __m128 FloatVector = _mm_set_ss(float(c[offset])); + + if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { + ScaleVector = _mm_load_ss(&scale[offset]); + } + + if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + FloatVector = _mm_add_ps(_mm_mul_ss(FloatVector, ScaleVector), _mm_load_ss(&c_out[offset])); + } else { + FloatVector = _mm_mul_ss(FloatVector, ScaleVector); + } + + if (HasBias) { + FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset])); + } + + _mm_store_ss(&c_out[offset], FloatVector); +#else + if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { + ScaleValue = scale[offset]; + } + + float result = float(c[offset]) * ScaleValue; + if (HasBias) { + result += bias[offset]; + } + + if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { + c_out[offset] += result; + } else { + c_out[offset] = result; + } +#endif + } + + C += ldc; + Output += LeadingDimensionOutput_; + } +} + +#endif // MLAS_SUPPORTS_GEMM_U8X8 diff --git a/onnxruntime/core/util/qmath.cc b/onnxruntime/core/util/qmath.cc index e3aa0b761f..2b62d885c0 100644 --- a/onnxruntime/core/util/qmath.cc +++ b/onnxruntime/core/util/qmath.cc @@ -85,7 +85,13 @@ void QGemm( const float* bias, concurrency::ThreadPool* thread_pool) { #ifdef MLAS_SUPPORTS_GEMM_U8X8 - MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, result_data, ldc, result_scale, bias, thread_pool); + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(result_data, ldc, result_scale, bias); + MlasGemm(M, N, K, + lhs_data, lda, lhs_offset, + rhs_data, ldb, rhs_offset, rhs_signed, + reinterpret_cast(result_data), ldc, + thread_pool, + &scale_bias_processor); #else QGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, reinterpret_cast(result_data), ldc, thread_pool); for (int m = 0; m < M; m++) { diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index 637da3816b..f9df282845 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -573,7 +573,13 @@ protected: const float* Bias ) { - MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc, &CScale, Bias, threadpool); + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias); + MlasGemm(M, N, K, + A, lda, offa, + B, ldb, offb, BIsSigned, + reinterpret_cast(C), ldc, + threadpool, + &scale_bias_processor); } }; @@ -638,7 +644,13 @@ protected: ) { const void* PackedB = PackB(N, K, B, ldb, BIsSigned); - MlasGemm(M, N, K, A, lda, offa, PackedB, offb, BIsSigned, C, ldc, &CScale, Bias, threadpool); + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias); + MlasGemm(M, N, K, + A, lda, offa, + PackedB, offb, BIsSigned, + reinterpret_cast(C), ldc, + threadpool, + &scale_bias_processor); } private: @@ -2756,6 +2768,92 @@ public: } }; +class MlasScaleOutputTest : public MlasTestBase +{ +private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputRef; + MatrixGuardBuffer BufferScale; + + void + Test( + size_t M, + size_t N, + bool PerColumn, + bool AccumulateMode + ) + { + int32_t* Input = BufferInput.GetBuffer(M * N); + float* Output = BufferOutput.GetBuffer(M * N); + float* OutputRef = BufferOutputRef.GetBuffer(M * N); + float* Scale = BufferScale.GetBuffer(PerColumn ? N : 1); + + std::default_random_engine generator(static_cast(M * N)); + std::uniform_real_distribution real_distribution(-1.0f, 1.0f); + std::uniform_int_distribution int_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + + for (size_t s = 0; s < M * N; s++) { + + Input[s] = int_distribution(generator); + Output[s] = OutputRef[s] = real_distribution(generator); + } + + for (size_t s = 0; s < (PerColumn ? N : 1); s++) { + Scale[s] = real_distribution(generator); + } + + // Compute Reference Value + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + + float current_scale = PerColumn ? Scale[n] : Scale[0]; + if (AccumulateMode) { + OutputRef[m * N + n] += Input[m * N + n] * current_scale; + } else { + OutputRef[m * N + n] = Input[m * N + n] * current_scale; + } + } + } + + // Compute Output with MLAS + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR OutputProcessor(Output, N, Scale, nullptr, + AccumulateMode ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode, + PerColumn ? MLAS_QUANTIZATION_GRANULARITY::PerColumn : MLAS_QUANTIZATION_GRANULARITY::PerMatrix); + OutputProcessor.Process(Input, 0, 0, M, N, N); + + constexpr float epsilon = 1e-6f; + + for (size_t n = 0; n < M * N; n++) { + + float diff = std::fabs(Output[n] - OutputRef[n]); + if (diff > epsilon) { + printf("MlasScaleOutputTest: Output[%zu][%zu]:%.8f, OutputRef[%zu][%zu]:%.8f, for case M=%zu, N=%zu\n", + n / N, n % N, Output[n], n / N, n % N, OutputRef[n], M, N); + } + } + } + +public: + void + ExecuteShort( + void + ) override + { + for (size_t m = 1; m < 18; m++) { + + for (size_t n = 1; n < 18; n++) { + + Test(m, n, true, true); + Test(m, n, true, false); + Test(m, n, false, true); + Test(m, n, false, false); + } + } + } +}; + void RunThreadedTests( void @@ -2870,6 +2968,9 @@ main( onnxruntime::make_unique( [] (float a, float b) { return a * b; }, "*", MlasQLinearMul, MlasQLinearMul)->ExecuteShort(); + printf("MlasScaleOutput tests.\n"); + onnxruntime::make_unique()->ExecuteShort(); + printf("Done.\n"); return 0;