diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 3fee9c0cb5..15323d6b59 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -28,6 +28,7 @@ if(MSVC) if(onnxruntime_target_platform STREQUAL "ARM64") set(mlas_platform_preprocess_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/SgemmKernelNeon.asm ) @@ -81,15 +82,13 @@ if(MSVC) ${mlas_platform_srcs_avx2} ${ONNXRUNTIME_ROOT}/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Core.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvxVnni.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm - ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Core.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelSse2.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelAvx.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/DgemmKernelFma3.asm @@ -183,6 +182,7 @@ else() set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemvKernelNeon.S ) @@ -246,8 +246,8 @@ else() ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx2.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvxVnni.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvxVnni.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/DgemmKernelFma3.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S @@ -322,11 +322,9 @@ else() if(COMPILES_AVX512CORE) set(mlas_platform_srcs_avx512core - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Core.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Core.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S ) if(HAS_AVX512CORE) set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index e871a69c86..f67a16baf8 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -22,10 +22,7 @@ class QAttention : public OpKernel, public AttentionCPUBase { QAttention(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; - -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override; -#endif private: BufferUniquePtr packed_weights_; @@ -51,7 +48,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template QAttention::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info) {} -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 template Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_packed) { is_packed = false; @@ -98,7 +94,6 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac is_packed = true; return Status::OK(); } -#endif template Status QAttention::Compute(OpKernelContext* context) const { @@ -217,44 +212,29 @@ Status QAttention::Compute(OpKernelContext* context) const { head_size, &dequant_scale, bias_data + weights_offset); -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = sequence_length; + gemm_params.N = head_size; + gemm_params.K = input_hidden_size; + gemm_params.A = input_data + input_offset; + gemm_params.lda = input_hidden_size; + gemm_params.ZeroPointA = input_zero_point; if (packed_weights_) { const auto* packed_weight = static_cast(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size); - - MlasGemm( - sequence_length, // M = S - head_size, // N = H - input_hidden_size, // K = D - input_data + input_offset, // A - input_hidden_size, // lda = D - input_zero_point, // input zero point - packed_weight, // B - weight_zero_point, // weight zero point - weights_is_signed, // weight data type - reinterpret_cast(qkv_dest + qkv_offset), // C - head_size, // ldc - nullptr, // use single-thread - &scale_bias_processor); // output processor - - continue; + gemm_params.B = packed_weight; + gemm_params.BIsPacked = true; + } else { + gemm_params.B = weights_data + weights_offset; + gemm_params.ldb = 3 * hidden_size; } -#endif - MlasGemm( - sequence_length, // M = S - head_size, // N = H - input_hidden_size, // K = D - input_data + input_offset, // A - input_hidden_size, // lda = D - input_zero_point, // input zero point - weights_data + weights_offset, // B - 3 * hidden_size, // ldb = 3NH - weight_zero_point, // weight zero point - weights_is_signed, // weight data type - reinterpret_cast(qkv_dest + qkv_offset), // C - head_size, // ldc - nullptr, // use single-thread - &scale_bias_processor); // post processor + gemm_params.ZeroPointB = &weight_zero_point; + gemm_params.BIsSigned = weights_is_signed; + gemm_params.C = reinterpret_cast(qkv_dest + qkv_offset); + gemm_params.ldc = head_size; + gemm_params.OutputProcessor = &scale_bias_processor; + MlasGemm(&gemm_params, nullptr); } }); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index 6a8636dfaa..4511a66f53 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -10,19 +10,14 @@ using namespace rnn::detail; class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { public: DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} - -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override; -#endif Status Compute(OpKernelContext* context) const override; ~DynamicQuantizeLSTM() override = default; private: -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 Status TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed); -#endif template Status ComputeImpl(OpKernelContext& context) const; @@ -33,7 +28,6 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { bool is_R_signed_; }; -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 Status DynamicQuantizeLSTM::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed) { const auto& shape = weights.Shape(); if (shape.NumDimensions() != 3) { @@ -83,7 +77,6 @@ Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, bool& i return Status::OK(); } -#endif #define WeightCheck(weight_shape, weight_name) \ if (weight_shape.NumDimensions() != 1 && weight_shape.NumDimensions() != 2 || \ diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 1d06d5f0d0..0d8d05fd69 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -44,11 +44,19 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, if (y->Shape().Size() == 0) return Status::OK(); + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = static_cast(helper.M()); + gemm_params.N = static_cast(helper.N()); + gemm_params.K = static_cast(helper.K()); + gemm_params.lda = gemm_params.K; + gemm_params.ZeroPointA = a_zero_point; + gemm_params.ldb = gemm_params.N; + gemm_params.ZeroPointB = &b_zero_point; + gemm_params.ldc = gemm_params.N; + auto* y_data = y->template MutableData(); const auto* bias_data = bias_tensor != nullptr ? bias_tensor->Data() : nullptr; - concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor( y_data + helper.OutputOffsets()[i], @@ -56,40 +64,18 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, &multiplier, bias_data); -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + gemm_params.A = a_data + helper.LeftOffsets()[i]; if (packed_b_) { - MlasGemm(static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - a_data + helper.LeftOffsets()[i], - static_cast(helper.K()), - a_zero_point, - packed_b_.get(), - b_zero_point, - b_is_signed_, - reinterpret_cast(y_data + helper.OutputOffsets()[i]), - static_cast(helper.N()), - thread_pool, - &scale_bias_processor); - continue; + gemm_params.B = packed_b_.get(); + gemm_params.BIsPacked = true; + gemm_params.BIsSigned = b_is_signed_; + } else { + gemm_params.B = static_cast(b->DataRaw()) + + helper.RightOffsets()[i]; + gemm_params.BIsSigned = b->IsDataType(); } -#endif - const auto* b_data = static_cast(b->DataRaw()); - const bool b_is_signed = b->IsDataType(); - MlasGemm(static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - a_data + helper.LeftOffsets()[i], - static_cast(helper.K()), - a_zero_point, - b_data + helper.RightOffsets()[i], - static_cast(helper.N()), - b_zero_point, - b_is_signed, - reinterpret_cast(y_data + helper.OutputOffsets()[i]), - static_cast(helper.N()), - thread_pool, - &scale_bias_processor); + gemm_params.C = reinterpret_cast(y_data) + helper.OutputOffsets()[i]; + gemm_params.OutputProcessor = &scale_bias_processor; + MlasGemm(&gemm_params, ctx->GetOperatorThreadPool()); } return Status::OK(); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index a6943074d6..aae34514bb 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -61,10 +61,6 @@ Abstract: #define MLAS_SUPPORTS_GEMM_DOUBLE #endif -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_ARM64) || (defined(MLAS_TARGET_ARM) && !defined(_MSC_VER)) -#define MLAS_SUPPORTS_PACKED_GEMM_U8X8 -#endif - // // Basic Linear Algebra Subprograms (BLAS) types. // @@ -273,41 +269,29 @@ private: MLAS_QUANTIZATION_GRANULARITY QuantGran_; }; -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - int32_t* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool, - const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr - ); +struct MLAS_GEMM_U8X8_PARAMETERS { + size_t M = 0; + size_t N = 0; + size_t K = 0; + const uint8_t* A = nullptr; + size_t lda = 0; + uint8_t ZeroPointA = 0; + const void* B = 0; + size_t ldb = 0; + const uint8_t* ZeroPointB = nullptr; + bool BIsPacked = false; + bool BIsSigned = false; + bool PerColumnZeroPoints = false; + int32_t* C = nullptr; + size_t ldc = 0; + const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr; +}; void MLASCALL MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const void* PackedB, - uint8_t offb, - bool BIsSigned, - int32_t* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool, - const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr + const MLAS_GEMM_U8X8_PARAMETERS* Parameters, + MLAS_THREADPOOL* ThreadPool ); // diff --git a/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S index 01799b196c..dbf1047f58 100644 --- a/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S @@ -25,7 +25,7 @@ Abstract: // Stack frame layout for the U8X8 kernel. // - .equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (6 * 4) + .equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (7 * 4) .equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (8 * 8) .equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedGeneralRegisters + .LGemmU8X8KernelFrame_SavedNeonRegisters .equ .LGemmU8X8KernelFrame_CountM, 0 + .LGemmU8X8KernelFrame_SavedRegisters @@ -33,7 +33,7 @@ Abstract: .equ .LGemmU8X8KernelFrame_ldc, 8 + .LGemmU8X8KernelFrame_SavedRegisters .equ .LGemmU8X8KernelFrame_RowSumBuffer, 12 + .LGemmU8X8KernelFrame_SavedRegisters .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 16 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_DepthValue, 20 + .LGemmU8X8KernelFrame_SavedRegisters + .equ .LGemmU8X8KernelFrame_ZeroPointB, 20 + .LGemmU8X8KernelFrame_SavedRegisters .equ .LGemmU8X8KernelFrame_ZeroMode, 24 + .LGemmU8X8KernelFrame_SavedRegisters .text @@ -48,10 +48,10 @@ Routine Description: Arguments: A (r0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackANeon. + using MlasGemmU8X8CopyPackA. B (r1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBNeon. + using MlasGemmU8X8CopyPackB. C (r2) - Supplies the address of matrix C. @@ -67,17 +67,18 @@ Arguments: ldc - Supplies the first dimension of matrix C. - RowSumBuffer - Supplies the sum of each row from matrix A multiplied by - the zero point offset of matrix B. These values are accumulated into every - row of matrix C. + RowSumBuffer - Supplies the sum of each row from matrix A. These values have + been pre-scaled by the zero point offset of matrix B if the offset is + per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied by the zero point offset of matrix A. These values are accumulated into every column of matrix C. - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrix A multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. ZeroMode - Supplies true if the output matrix must be zero initialized, else false if the output matrix is accumulated into. @@ -95,26 +96,24 @@ Return Value: // // q0-q1 (d0-d3) matrix B data // q2-q3 (d4-d7) matrix A data -// q4 (d8-d9) packed matrix B data -// q5 (d10-d11) RowSumBufferData + DepthValue +// q4 (d8-d9) packed matrix A data +// q5 (d10-d11) RowSumBuffer data // q6-q7 (d12-d15) ColumnSumBuffer data // q8-q15 accumulators[4][2] // - push {r4,r5,r6,r7,r8,r10} + push {r4,r5,r6,r7,r8,r9,r10} vpush {d8-d15} ldr r4,[sp,#.LGemmU8X8KernelFrame_CountM] ldr r5,[sp,#.LGemmU8X8KernelFrame_ZeroMode] - ldr r7,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer] + ldr r7,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] ldr r8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] - vldr s0,[sp,#.LGemmU8X8KernelFrame_DepthValue] + ldr r9,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer] ldr r10,[sp,#.LGemmU8X8KernelFrame_ldc] ldr r12,[sp,#.LGemmU8X8KernelFrame_CountN] - vdup.32 q5,d0[0] // broadcast DepthValue - vld1.32 {d12-d13},[r7] + vld1.32 {d10-d11},[r9] // load RowSumBuffer mov r6,r0 - mov r7,r3 - vadd.u32 q5,q5,q6 // add row fixups and DepthValue + mov r9,r3 cmp r4,#1 // CountM == 1? beq .LGemmU8X8.M1.ProcessNextColumnLoop cmp r4,#4 // CountM < 4? @@ -128,12 +127,36 @@ Return Value: vldr d0,[r1] // load packed B0 mov r0,r6 // reload matrix A vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer - mov r3,r7 // reload PackedCountK + mov r3,r9 // reload PackedCountK vmovl.u8 q0,d0 vdup.32 q9,d10[0] vdup.32 q11,d10[1] vdup.32 q13,d11[0] vdup.32 q15,d11[1] + cbz r7,.LGemmU8X8.M4.SkipScaleByZeroPointB + vld1.32 {d8-d9},[r7]! // load ZeroPointB0 + vmul.u32 q8,q9,q4 + vmul.u32 q10,q11,q4 + vmul.u32 q12,q13,q4 + vmul.u32 q14,q15,q4 + vld1.32 {d8-d9},[r7]! // load ZeroPointB1 + vmul.u32 q9,q9,q4 + vmul.u32 q11,q11,q4 + vmul.u32 q13,q13,q4 + vmul.u32 q15,q15,q4 + vldr d8,[r0] // load first packed A0 + vadd.u32 q8,q8,q6 + vadd.u32 q9,q9,q7 + vadd.u32 q10,q10,q6 + vadd.u32 q11,q11,q7 + vldr d9,[r0,#8] // load first packed A1 + vadd.u32 q12,q12,q6 + vadd.u32 q13,q13,q7 + vadd.u32 q14,q14,q6 + vadd.u32 q15,q15,q7 + b .LGemmU8X8.M4.ComputeBlockLoop + +.LGemmU8X8.M4.SkipScaleByZeroPointB: vldr d8,[r0] // load first packed A0 vadd.u32 q8,q9,q6 vadd.u32 q9,q9,q7 @@ -244,7 +267,7 @@ Return Value: .LGemmU8X8.M4.ExitKernel: mov r0,#4 // return number of rows handled vpop {d8-d15} - pop {r4,r5,r6,r7,r8,r10} + pop {r4,r5,r6,r7,r8,r9,r10} bx lr // @@ -353,10 +376,24 @@ Return Value: vldr d0,[r1] // load packed B0 mov r0,r6 // reload matrix A vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer - mov r3,r7 // reload PackedCountK + mov r3,r9 // reload PackedCountK vmovl.u8 q0,d0 vdup.32 q9,d10[0] vdup.32 q11,d10[1] + cbz r7,.LGemmU8X8.M2.SkipScaleByZeroPointB + vld1.32 {d28-d31},[r7]! // load ZeroPointB + vmul.u32 q8,q9,q14 + vmul.u32 q9,q9,q15 + vmul.u32 q10,q11,q14 + vmul.u32 q11,q11,q15 + vld1.32 d8,[r0]! // load first packed A0 + vadd.u32 q8,q8,q6 + vadd.u32 q9,q9,q7 + vadd.u32 q10,q10,q6 + vadd.u32 q11,q11,q7 + b .LGemmU8X8.M2.ComputeBlockLoop + +.LGemmU8X8.M2.SkipScaleByZeroPointB: vld1.32 d8,[r0]! // load first packed A0 vadd.u32 q8,q9,q6 vadd.u32 q9,q9,q7 @@ -425,7 +462,7 @@ Return Value: .LGemmU8X8.M2.ExitKernel: mov r0,#2 // return number of rows handled vpop {d8-d15} - pop {r4,r5,r6,r7,r8,r10} + pop {r4,r5,r6,r7,r8,r9,r10} bx lr // @@ -502,9 +539,19 @@ Return Value: vldr d0,[r1] // load packed B0 mov r0,r6 // reload matrix A vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer - mov r3,r7 // reload PackedCountK + mov r3,r9 // reload PackedCountK vmovl.u8 q0,d0 vdup.32 q9,d10[0] + cbz r7,.LGemmU8X8.M1.SkipScaleByZeroPointB + vld1.32 {d28-d31},[r7]! // load ZeroPointB + vmul.u32 q8,q9,q14 + vmul.u32 q9,q9,q15 + vld1.32 d8[0],[r0]! // load first packed A0 + vadd.u32 q8,q8,q6 + vadd.u32 q9,q9,q7 + b .LGemmU8X8.M1.ComputeBlockLoop + +.LGemmU8X8.M1.SkipScaleByZeroPointB: vld1.32 d8[0],[r0]! // load first packed A0 vadd.u32 q8,q9,q6 vadd.u32 q9,q9,q7 @@ -554,7 +601,7 @@ Return Value: .LGemmU8X8.M1.ExitKernel: mov r0,#1 // return number of rows handled vpop {d8-d15} - pop {r4,r5,r6,r7,r8,r10} + pop {r4,r5,r6,r7,r8,r9,r10} bx lr // diff --git a/onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h b/onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h new file mode 100644 index 0000000000..92c4d38980 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h @@ -0,0 +1,52 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + AssembleDotProduct.h + +Abstract: + + This module contains macros to build Advanced SIMD dot product instructions + for toolchains that do not natively support this newer instruction set + extension. + + This implementation uses ARM v8.4 dot product instructions. + +--*/ + +/*++ + +Macro Description: + + This macro builds a UDOT instruction of the form: + + UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index] + +Arguments: + + DestReg - Specifies the destination register. + + Src1Reg - Specifies the first source register. + + Src2Reg - Specifies the second source register. + + Index - Specifies the element index of the second source register. + +--*/ + + .macro UdotByElement DestReg, Src1Reg, Src2Reg, Index + + .set Instruction, 0x6F80E000 + .set Instruction, Instruction + (\DestReg\() << 0) + .set Instruction, Instruction + (\Src1Reg\() << 5) + .set Instruction, Instruction + (\Src2Reg\() << 16) + .set Instruction, Instruction + ((\Index\() & 2) << 10) + .set Instruction, Instruction + ((\Index\() & 1) << 21) + + .inst Instruction + + .endm diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S index 6f1d7f9ec5..abd367c6ec 100644 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S @@ -22,12 +22,8 @@ Abstract: // .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0 - .equ .LGemmU8X8KernelFrame_DepthValue, 8 -#if defined(__APPLE__) - .equ .LGemmU8X8KernelFrame_ZeroMode, 12 -#else + .equ .LGemmU8X8KernelFrame_ZeroPointB, 8 .equ .LGemmU8X8KernelFrame_ZeroMode, 16 -#endif .text @@ -41,10 +37,10 @@ Routine Description: Arguments: A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackANeon. + using MlasGemmU8X8CopyPackA. B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBNeon. + using MlasGemmU8X8CopyPackB. C (x2) - Supplies the address of matrix C. @@ -60,17 +56,18 @@ Arguments: ldc (x6) - Supplies the first dimension of matrix C. - RowSumBuffer (x7) - Supplies the sum of each row from matrix A multiplied by - the zero point offset of matrix B. These values are accumulated into every - row of matrix C. + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied by the zero point offset of matrix A. These values are accumulated into every column of matrix C. - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrix A multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. ZeroMode - Supplies true if the output matrix must be zero initialized, else false if the output matrix is accumulated into. @@ -84,13 +81,11 @@ Return Value: FUNCTION_ENTRY MlasGemmU8X8KernelNeon ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] - ldr s27,[sp,#.LGemmU8X8KernelFrame_DepthValue] + ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode] - dup v27.4s,v27.s[0] mov x14,x0 - ld1 {v0.4s},[x7] + ld1 {v27.4s},[x7] mov x15,x3 - add v27.4s,v27.4s,v0.4s // broadcast add DepthValue dup v24.4s,v27.s[0] // broadcast row fixups cmp x4,#1 // CountM == 1? beq .LGemmU8X8.M1.ProcessNextColumnLoop @@ -111,6 +106,30 @@ Return Value: mov x3,x15 // reload PackedCountK ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 uxtl v0.8h,v0.8b + cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB + ld1 {v28.4s},[x9],#16 // load ZeroPointB0 + ld1 {v29.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v24.4s,v28.4s + mul v17.4s,v24.4s,v29.4s + mul v18.4s,v25.4s,v28.4s + mul v19.4s,v25.4s,v29.4s + mul v20.4s,v26.4s,v28.4s + mul v21.4s,v26.4s,v29.4s + mul v22.4s,v27.4s,v28.4s + mul v23.4s,v27.4s,v29.4s + ld1 {v4.8b},[x0],#8 // load first packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + add v18.4s,v2.4s,v18.4s + add v19.4s,v3.4s,v19.4s + ld1 {v5.8b},[x0],#8 // load first packed A1 + add v20.4s,v2.4s,v20.4s + add v21.4s,v3.4s,v21.4s + add v22.4s,v2.4s,v22.4s + add v23.4s,v3.4s,v23.4s + b .LGemmU8X8.M4.ComputeBlockLoop + +.LGemmU8X8.M4.SkipScaleByZeroPointB: ld1 {v4.8b},[x0],#8 // load first packed A0 add v16.4s,v2.4s,v24.4s add v17.4s,v3.4s,v24.4s @@ -322,6 +341,21 @@ Return Value: mov x3,x15 // reload PackedCountK ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 uxtl v0.8h,v0.8b + cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB + ld1 {v28.4s},[x9],#16 // load ZeroPointB0 + ld1 {v29.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v24.4s,v28.4s + mul v17.4s,v24.4s,v29.4s + mul v18.4s,v25.4s,v28.4s + mul v19.4s,v25.4s,v29.4s + ld1 {v4.8b},[x0],#8 // load first packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + add v18.4s,v2.4s,v18.4s + add v19.4s,v3.4s,v19.4s + b .LGemmU8X8.M2.ComputeBlockLoop + +.LGemmU8X8.M2.SkipScaleByZeroPointB: ld1 {v4.8b},[x0],#8 // load first packed A0 add v16.4s,v2.4s,v24.4s add v17.4s,v3.4s,v24.4s @@ -460,6 +494,17 @@ Return Value: mov x3,x15 // reload PackedCountK ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 uxtl v0.8h,v0.8b + cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB + ld1 {v28.4s},[x9],#16 // load ZeroPointB0 + ld1 {v29.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v24.4s,v28.4s + mul v17.4s,v24.4s,v29.4s + ldr s4,[x0],#4 // load first packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + b .LGemmU8X8.M1.ComputeBlockLoop + +.LGemmU8X8.M1.SkipScaleByZeroPointB: ldr s4,[x0],#4 // load first packed A0 add v16.4s,v2.4s,v24.4s add v17.4s,v3.4s,v24.4s diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S new file mode 100644 index 0000000000..d0c086bb1a --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S @@ -0,0 +1,585 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUdot.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses ARM v8.4 dot product instructions. + +--*/ + +#include "asmmacro.h" +#include "AssembleDotProduct.h" + +// +// Stack frame layout for the U8X8 kernel. +// + + .equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (4 * 8) + .equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedNeonRegisters + .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0 + .LGemmU8X8KernelFrame_SavedRegisters + .equ .LGemmU8X8KernelFrame_ZeroPointB, 8 + .LGemmU8X8KernelFrame_SavedRegisters + .equ .LGemmU8X8KernelFrame_ZeroMode, 16 + .LGemmU8X8KernelFrame_SavedRegisters + + .text + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8X8CopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8X8CopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + + ZeroMode - Supplies true if the output matrix must be zero initialized, else + false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + FUNCTION_ENTRY MlasGemmU8X8KernelUdot + + stp d8,d9,[sp,#-32]! + stp d10,d11,[sp,#16] + ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] + ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] + ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode] + mov x14,x0 + ld1 {v11.4s},[x7] + mov x15,x3 + dup v8.4s,v11.s[0] // broadcast row fixups + cmp x4,#1 // CountM == 1? + beq .LGemmU8X8.M1.ProcessNextColumnLoop + dup v9.4s,v11.s[1] + cmp x4,#4 // CountM < 4? + blo .LGemmU8X8.M2.ProcessNextColumnLoop + dup v10.4s,v11.s[2] + dup v11.4s,v11.s[3] + +// +// Process 4 rows of the matrices. +// + +.LGemmU8X8.M4.ProcessNextColumnLoop: + ld1 {v0.16b},[x1],#16 // load packed B0 + mov x0,x14 // reload matrix A + ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] + mov x3,x15 // reload PackedCountK + ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] + cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB + ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] + mul v16.4s,v30.4s,v8.4s + mul v18.4s,v30.4s,v9.4s + ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] + mul v20.4s,v30.4s,v10.4s + mul v22.4s,v30.4s,v11.4s + mul v17.4s,v31.4s,v8.4s + mul v19.4s,v31.4s,v9.4s + mul v21.4s,v31.4s,v10.4s + mul v23.4s,v31.4s,v11.4s + add v16.4s,v2.4s,v16.4s + add v18.4s,v2.4s,v18.4s + add v20.4s,v2.4s,v20.4s + add v22.4s,v2.4s,v22.4s + add v17.4s,v3.4s,v17.4s + add v19.4s,v3.4s,v19.4s + add v21.4s,v3.4s,v21.4s + add v23.4s,v3.4s,v23.4s + b .LGemmU8X8.M4.ComputeBlockLoopStart + +.LGemmU8X8.M4.SkipScaleByZeroPointB: + add v16.4s,v2.4s,v8.4s + add v18.4s,v2.4s,v9.4s + add v20.4s,v2.4s,v10.4s + add v22.4s,v2.4s,v11.4s + add v17.4s,v3.4s,v8.4s + add v19.4s,v3.4s,v9.4s + add v21.4s,v3.4s,v10.4s + add v23.4s,v3.4s,v11.4s + +// +// The packing layout is setup to have a pair of four quad vectors from +// packed matrix A and a pair of eight quad vectors from packed matrix B. +// With this scheme, alternating loads from the packed matrices can be +// interleaved with the dot product instructions. +// +// One negative consequence of using four rows here is that the accumulator +// register tile is too small for processors with high out of order execution +// windows (such as the Apple M1). The dot product instructions for a given +// cell are too close to each other to avoid dependencies. To workaround this, +// the below loop uses a pair of accumulator registers that are then added +// together when the loop finishes. +// +// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for +// packed matrix A. At the time of this implementation, using a wider 128-bit +// load didn't affect performance for higher end cores. +// + +.LGemmU8X8.M4.ComputeBlockLoopStart: + ldr d4,[x0],#32 // load packed A0.l + movi v24.4s,#0 + movi v25.4s,#0 + ldur d5,[x0,#-24] // load packed A0.h + movi v26.4s,#0 + movi v27.4s,#0 + ldur d6,[x0,#-16] // load packed A1.l + movi v28.4s,#0 + movi v29.4s,#0 + movi v30.4s,#0 + movi v31.4s,#0 + +.LGemmU8X8.M4.ComputeBlockLoop: + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 16, 0, 4, 0 + UdotByElement 18, 0, 4, 1 + ldur d7,[x0,#-8] // load packed A1.h + UdotByElement 20, 0, 5, 0 + UdotByElement 22, 0, 5, 1 + ld1 {v0.16b},[x1],#16 // load packed B0 + UdotByElement 17, 1, 4, 0 + UdotByElement 19, 1, 4, 1 + sub x3,x3,#1 + cbz x3,.LGemmU8X8.M4.ComputeBlockLoopFinish + ldr d4,[x0],#32 // load packed A0.l + UdotByElement 21, 1, 5, 0 + UdotByElement 23, 1, 5, 1 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 24, 0, 6, 0 + UdotByElement 26, 0, 6, 1 + ldur d5,[x0,#-24] // load packed A0.h + UdotByElement 28, 0, 7, 0 + UdotByElement 30, 0, 7, 1 + ld1 {v0.16b},[x1],#16 // load packed B0 + UdotByElement 25, 1, 6, 0 + UdotByElement 27, 1, 6, 1 + ldur d6,[x0,#-16] // load packed A1.l + UdotByElement 29, 1, 7, 0 + UdotByElement 31, 1, 7, 1 + b .LGemmU8X8.M4.ComputeBlockLoop + +.LGemmU8X8.M4.ComputeBlockLoopFinish: + UdotByElement 21, 1, 5, 0 + UdotByElement 23, 1, 5, 1 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 24, 0, 6, 0 + UdotByElement 26, 0, 6, 1 + UdotByElement 28, 0, 7, 0 + UdotByElement 30, 0, 7, 1 + UdotByElement 25, 1, 6, 0 + UdotByElement 27, 1, 6, 1 + UdotByElement 29, 1, 7, 0 + UdotByElement 31, 1, 7, 1 + add x10,x2,x6,lsl #2 // compute output row 2 + add v16.4s,v16.4s,v24.4s // fold high results into low results + add v18.4s,v18.4s,v26.4s + add v20.4s,v20.4s,v28.4s + add v22.4s,v22.4s,v30.4s + add x11,x10,x6,lsl #2 // compute output row 3 + add v17.4s,v17.4s,v25.4s + add v19.4s,v19.4s,v27.4s + add v21.4s,v21.4s,v29.4s + add v23.4s,v23.4s,v31.4s + add x12,x11,x6,lsl #2 // compute output row 4 + subs x5,x5,#8 // adjust CountN remaining + blo .LGemmU8X8.M4.StoreOutputPartial + cbnz x13,.LGemmU8X8.M4.SkipAccumulateOutput + ldp q0,q1,[x2] + ldp q2,q3,[x10] + add v16.4s,v16.4s,v0.4s + add v17.4s,v17.4s,v1.4s + ldp q4,q5,[x11] + add v18.4s,v18.4s,v2.4s + add v19.4s,v19.4s,v3.4s + ldp q6,q7,[x12] + add v20.4s,v20.4s,v4.4s + add v21.4s,v21.4s,v5.4s + add v22.4s,v22.4s,v6.4s + add v23.4s,v23.4s,v7.4s + +.LGemmU8X8.M4.SkipAccumulateOutput: + stp q16,q17,[x2],#32 + stp q18,q19,[x10] + stp q20,q21,[x11] + stp q22,q23,[x12] + cbnz x5,.LGemmU8X8.M4.ProcessNextColumnLoop + +.LGemmU8X8.M4.ExitKernel: + mov x0,#4 // return number of rows handled + ldp d10,d11,[sp,#16] + ldp d8,d9,[sp],#32 + ret + +// +// Store the partial 1 to 7 columns either overwriting the output matrix or +// accumulating into the existing contents of the output matrix. +// + +.LGemmU8X8.M4.StoreOutputPartial: + cbz x13,.LGemmU8X8.M4.StoreOutputPartial.AddMode + +.LGemmU8X8.M4.StoreOutputPartial.ZeroMode: + tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + st1 {v20.4s},[x11],#16 + mov v20.16b,v21.16b + st1 {v22.4s},[x12],#16 + mov v22.16b,v23.16b + +.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode: + tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + st1 {v20.2s},[x11],#8 + dup v20.4s,v20.s[2] + st1 {v22.2s},[x12],#8 + dup v22.4s,v22.s[2] + +.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode: + tbz x5,#0,.LGemmU8X8.M4.ExitKernel + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + st1 {v20.s}[0],[x11] + st1 {v22.s}[0],[x12] + b .LGemmU8X8.M4.ExitKernel + +.LGemmU8X8.M4.StoreOutputPartial.AddMode: + tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.AddMode + ld1 {v0.4s},[x2] + ld1 {v1.4s},[x10] + ld1 {v2.4s},[x11] + ld1 {v3.4s},[x12] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + add v20.4s,v20.4s,v2.4s + add v22.4s,v22.4s,v3.4s + st1 {v20.4s},[x11],#16 + mov v20.16b,v21.16b + st1 {v22.4s},[x12],#16 + mov v22.16b,v23.16b + +.LGemmU8X8.M4.StoreOutputPartial2.AddMode: + tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.AddMode + ld1 {v0.2s},[x2] + ld1 {v1.2s},[x10] + ld1 {v2.2s},[x11] + ld1 {v3.2s},[x12] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + add v20.4s,v20.4s,v2.4s + add v22.4s,v22.4s,v3.4s + st1 {v20.2s},[x11],#8 + dup v20.4s,v20.s[2] + st1 {v22.2s},[x12],#8 + dup v22.4s,v22.s[2] + +.LGemmU8X8.M4.StoreOutputPartial1.AddMode: + tbz x5,#0,.LGemmU8X8.M4.ExitKernel + ld1 {v0.s}[0],[x2] + ld1 {v1.s}[0],[x10] + add v16.4s,v16.4s,v0.4s + ld1 {v2.s}[0],[x11] + add v18.4s,v18.4s,v1.4s + ld1 {v3.s}[0],[x12] + add v20.4s,v20.4s,v2.4s + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + add v22.4s,v22.4s,v3.4s + st1 {v20.s}[0],[x11] + st1 {v22.s}[0],[x12] + b .LGemmU8X8.M4.ExitKernel + +// +// Process 2 rows of the matrices. +// + +.LGemmU8X8.M2.ProcessNextColumnLoop: + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + mov x0,x14 // reload matrix A + ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] + mov x3,x15 // reload PackedCountK + ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] + cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB + ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] + ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] + mul v16.4s,v30.4s,v8.4s + mul v18.4s,v30.4s,v9.4s + mul v17.4s,v31.4s,v8.4s + mul v19.4s,v31.4s,v9.4s + ld1 {v4.16b},[x0],#16 // load packed A0 + add v16.4s,v2.4s,v16.4s + add v18.4s,v2.4s,v18.4s + add v17.4s,v3.4s,v17.4s + add v19.4s,v3.4s,v19.4s + b .LGemmU8X8.M2.ComputeBlockLoop + +.LGemmU8X8.M2.SkipScaleByZeroPointB: + ld1 {v4.16b},[x0],#16 // load packed A0 + add v16.4s,v2.4s,v8.4s + add v18.4s,v2.4s,v9.4s + add v17.4s,v3.4s,v8.4s + add v19.4s,v3.4s,v9.4s + +.LGemmU8X8.M2.ComputeBlockLoop: + UdotByElement 16, 0, 4, 0 + UdotByElement 17, 1, 4, 0 + UdotByElement 18, 0, 4, 1 + UdotByElement 19, 1, 4, 1 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 16, 0, 4, 2 + UdotByElement 17, 1, 4, 2 + UdotByElement 18, 0, 4, 3 + UdotByElement 19, 1, 4, 3 + sub x3,x3,#1 + cbz x3,.LGemmU8X8.M2.ComputeBlockLoopFinish + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + ld1 {v4.16b},[x0],#16 // load packed A0 + b .LGemmU8X8.M2.ComputeBlockLoop + +.LGemmU8X8.M2.ComputeBlockLoopFinish: + add x10,x2,x6,lsl #2 // compute output row 2 + subs x5,x5,#8 // adjust CountN remaining + blo .LGemmU8X8.M2.StoreOutputPartial + cbnz x13,.LGemmU8X8.M2.SkipAccumulateOutput + ldp q0,q1,[x2] + ldp q2,q3,[x10] + add v16.4s,v16.4s,v0.4s + add v17.4s,v17.4s,v1.4s + add v18.4s,v18.4s,v2.4s + add v19.4s,v19.4s,v3.4s + +.LGemmU8X8.M2.SkipAccumulateOutput: + stp q16,q17,[x2],#32 + stp q18,q19,[x10] + cbnz x5,.LGemmU8X8.M2.ProcessNextColumnLoop + +.LGemmU8X8.M2.ExitKernel: + mov x0,#2 // return number of rows handled + ldp d10,d11,[sp,#16] + ldp d8,d9,[sp],#32 + ret + +// +// Store the partial 1 to 7 columns either overwriting the output matrix or +// accumulating into the existing contents of the output matrix. +// + +.LGemmU8X8.M2.StoreOutputPartial: + cbz x13,.LGemmU8X8.M2.StoreOutputPartial.AddMode + +.LGemmU8X8.M2.StoreOutputPartial.ZeroMode: + tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + +.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode: + tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + +.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode: + tbz x5,#0,.LGemmU8X8.M2.ExitKernel + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + b .LGemmU8X8.M2.ExitKernel + +.LGemmU8X8.M2.StoreOutputPartial.AddMode: + tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.AddMode + ld1 {v0.4s},[x2] + ld1 {v1.4s},[x10] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + +.LGemmU8X8.M2.StoreOutputPartial2.AddMode: + tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.AddMode + ld1 {v0.2s},[x2] + ld1 {v1.2s},[x10] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + +.LGemmU8X8.M2.StoreOutputPartial1.AddMode: + tbz x5,#0,.LGemmU8X8.M2.ExitKernel + ld1 {v0.s}[0],[x2] + ld1 {v1.s}[0],[x10] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + b .LGemmU8X8.M2.ExitKernel + +// +// Process 1 row of the matrices. +// + +.LGemmU8X8.M1.ProcessNextColumnLoop: + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + mov x0,x14 // reload matrix A + ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 + mov x3,x15 // reload PackedCountK + ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 + cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB + ld1 {v30.4s},[x9],#16 // load ZeroPointB0 + ld1 {v31.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v30.4s,v8.4s + mul v17.4s,v31.4s,v8.4s + ldr d4,[x0],#8 // load packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + b .LGemmU8X8.M1.ComputeBlockLoop + +.LGemmU8X8.M1.SkipScaleByZeroPointB: + ldr d4,[x0],#8 // load packed A0 + add v16.4s,v2.4s,v8.4s + add v17.4s,v3.4s,v8.4s + +.LGemmU8X8.M1.ComputeBlockLoop: + UdotByElement 16, 0, 4, 0 + UdotByElement 17, 1, 4, 0 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 16, 0, 4, 1 + UdotByElement 17, 1, 4, 1 + sub x3,x3,#1 + cbz x3,.LGemmU8X8.M1.ComputeBlockLoopFinish + ldr d4,[x0],#8 // load packed A0 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + b .LGemmU8X8.M1.ComputeBlockLoop + +.LGemmU8X8.M1.ComputeBlockLoopFinish: + subs x5,x5,#8 // adjust CountN remaining + blo .LGemmU8X8.M1.StoreOutputPartial + cbnz x13,.LGemmU8X8.M1.SkipAccumulateOutput + ldp q0,q1,[x2] + add v16.4s,v16.4s,v0.4s + add v17.4s,v17.4s,v1.4s + +.LGemmU8X8.M1.SkipAccumulateOutput: + stp q16,q17,[x2],#32 + cbnz x5,.LGemmU8X8.M1.ProcessNextColumnLoop + +.LGemmU8X8.M1.ExitKernel: + mov x0,#1 // return number of rows handled + ldp d10,d11,[sp,#16] + ldp d8,d9,[sp],#32 + ret + +// +// Store the partial 1 to 7 columns either overwriting the output matrix or +// accumulating into the existing contents of the output matrix. +// + +.LGemmU8X8.M1.StoreOutputPartial: + cbz x13,.LGemmU8X8.M1.StoreOutputPartial.AddMode + +.LGemmU8X8.M1.StoreOutputPartial.ZeroMode: + tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + +.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode: + tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + +.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode: + tbz x5,#0,.LGemmU8X8.M1.ExitKernel + st1 {v16.s}[0],[x2] + b .LGemmU8X8.M1.ExitKernel + +.LGemmU8X8.M1.StoreOutputPartial.AddMode: + tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.AddMode + ld1 {v0.4s},[x2] + add v16.4s,v16.4s,v0.4s + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + +.LGemmU8X8.M1.StoreOutputPartial2.AddMode: + tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.AddMode + ld1 {v0.2s},[x2] + add v16.4s,v16.4s,v0.4s + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + +.LGemmU8X8.M1.StoreOutputPartial1.AddMode: + tbz x5,#0,.LGemmU8X8.M1.ExitKernel + ld1 {v0.s}[0],[x2] + add v16.4s,v16.4s,v0.4s + st1 {v16.s}[0],[x2] + b .LGemmU8X8.M1.ExitKernel + + .end diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm index f785833d98..cccea70c14 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm @@ -19,9 +19,10 @@ .xlist INCLUDE mlasi.inc -INCLUDE QgemmU8X8KernelAvx2Common.inc .list + EXTERN MlasMaskMoveTableAvx:NEAR + ; ; Stack frame layout for the U8S8 CopyPackA routine. ; @@ -142,9 +143,9 @@ GemmU8S8CopyPackBFrame ENDS and eax,15 ; isolate unaligned count add eax,3 shr eax,2 ; align unaligned count to quad count - mov DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp],eax - vpbroadcastd xmm10,DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp] - vpcmpgtd xmm10,xmm10,XMMWORD PTR [MlasMaskMoveAvx] + neg rax + lea rbx,MlasMaskMoveTableAvx+8*4 + vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] ; ; Zero initialize the padded stack buffers. @@ -776,281 +777,4 @@ StoreColumnSumBufferNUnaligned: NESTED_END MlasGemmU8S8CopyPackBAvx2, _TEXT -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single row of the -; output block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 16). -; -; Vec2Reg - Supplies the low block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from matrix B. -; -; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 16). -; -; ymm2 - Supplies the broadcast value loaded from matrix A. -; -; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. -; - -MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg - - vpmaddubsw ymm3,ymm2,ymm0 - vpmaddwd ymm3,ymm3,ymm12 -IF ColumnCount EQ 16 - vpaddd Vec1Reg,Vec1Reg,ymm3 - vpmaddubsw ymm2,ymm2,ymm1 - vpmaddwd ymm2,ymm2,ymm12 - vpaddd Vec2Reg,Vec2Reg,ymm2 -ELSE - vpaddd Vec2Reg,Vec2Reg,ymm3 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm11 - Supplies the block accumulators. -; -; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. -; - -ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - -IF RowCount EQ 1 - vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset] - vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset] - vpmaddwd ymm3,ymm3,ymm12 -IF ColumnCount EQ 16 - vpaddd ymm4,ymm4,ymm3 - vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32] - vpmaddwd ymm2,ymm2,ymm12 - vpaddd ymm5,ymm5,ymm2 -ELSE - vpaddd ymm5,ymm5,ymm3 -ENDIF -ELSE - vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm11 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO ColumnCount, RowCount - - LOCAL ComputeBlockBy1Loop - - mov rsi,r9 ; reload row length remaining - -ComputeBlockBy1Loop: - ComputeBlock ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,64 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop - - ENDM - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmU8S8CopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmU8S8CopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PackedCountK (r9) - Supplies the number of packed columns from matrix A and -; the number of packed rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the -; zero point offset of matrix B. These values are accumulated into every -; row of matrix C. -; -; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. These values are accumulated into -; every column of matrix C. -; -; DepthValue - Supplies the value CountK multiplied by the zero point offset -; of matrix A multplied by the zero point offset of matrix B. This value is -; accumulated into every element of matrix C. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemmU8S8KernelAvx2, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (GemmU8X8KernelFrame.SavedR13) - save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 - save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7 - save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8 - save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9 - save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10 - save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11 - save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12 - - END_PROLOGUE - - mov rdi,rcx - mov rbp,GemmU8X8KernelFrame.CountN[rsp] - mov rax,GemmU8X8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - shl r9,2 ; convert to row length - movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] - mov r11,GemmU8X8KernelFrame.CountM[rsp] - mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] - mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] - vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] - vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] - -; -; Process CountM rows of the matrices. -; - - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - mov r11d,4 ; return 4 rows handled - ProcessCountM 4, Fallthrough - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - mov eax,r11d - vzeroupper - movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp] - movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp] - movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp] - movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp] - movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp] - movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp] - movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp] - add rsp,(GemmU8X8KernelFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - - NESTED_END MlasGemmU8S8KernelAvx2, _TEXT - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc deleted file mode 100644 index aa132a7df9..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc +++ /dev/null @@ -1,91 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8S8KernelAvx512Common.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the quantized -; integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and -; AVX512VNNI kernels. -; -;-- - -INCLUDE QgemmU8X8KernelAvx512Common.inc - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO ColumnCount, RowCount - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL ComputeBlockLoopExit - - mov rsi,r9 ; reload row length remaining - -IF ((RowCount AND 1) EQ 0) - sub rsi,4*4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - ComputeBlock ColumnCount, RowCount, 0*64, 0 - ComputeBlock ColumnCount, RowCount, 1*64, 4 - ComputeBlock ColumnCount, RowCount, 2*64, 8 - ComputeBlock ColumnCount, RowCount, 3*64, 12 - add rcx,4*4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,4*64 ; advance matrix B - sub rsi,4*4 ; decrement quads remaining - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add rsi,4*4 ; correct for over-subtract above - jz ComputeBlockLoopExit -ENDIF - -ComputeBlockBy1Loop: - ComputeBlock ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,64 ; advance matrix B - sub rsi,4 ; decrement quads remaining - jnz ComputeBlockBy1Loop - -ComputeBlockLoopExit: - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Core.asm deleted file mode 100644 index f6c8d2d327..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Core.asm +++ /dev/null @@ -1,130 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8S8KernelAvx512Core.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX512 core instructions (BW/DQ/VL). -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE QgemmU8S8KernelAvx512Common.inc - .list - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single cell of the -; output block. -; -; Arguments: -; -; AccumReg - Supplies the register to accumulate into. -; -; Mult1Reg - Supplies the first multiplication operand register. -; -; Mult2Reg - Supplies the second multiplication operand register. -; -; Implicit Arguments: -; -; zmm4 - Supplies a scratch register for intermediate results. -; -; zmm5 - Supplies a 512-bit with the broadcasted word value 0x0001. -; - -MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg - - vpmaddubsw zmm4,Mult1Reg,Mult2Reg - vpmaddwd zmm4,zmm4,zmm5 - vpaddd AccumReg,AccumReg,zmm4 - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - -IF ColumnCount GE 48 - vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset] - vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset] - vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset] -ELSEIF ColumnCount GE 32 - vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset] - vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset] -ELSE - vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset] -ENDIF - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - - ENDM - -; -; Generate the GEMM kernel. -; - -GemmU8X8KernelAvx512Function U8S8, Avx512Core - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm deleted file mode 100644 index ddab81a6a4..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm +++ /dev/null @@ -1,102 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8S8KernelAvx512Vnni.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX512VNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE QgemmU8S8KernelAvx512Common.inc -INCLUDE AssembleAvx512Vnni.inc - .list - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - -IF ColumnCount GE 48 - vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset] - vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset] - vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset] -ELSEIF ColumnCount GE 32 - vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset] - vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset] -ELSE - vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset] -ENDIF - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - - ENDM - -; -; Generate the GEMM kernel. -; - -GemmU8X8KernelAvx512Function U8S8, Avx512Vnni - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvxVnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvxVnni.asm deleted file mode 100644 index bd847336eb..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvxVnni.asm +++ /dev/null @@ -1,298 +0,0 @@ -;++ -; -; Copyright (c) Intel Corporation 2020. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8S8KernelAvxVnni.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVXVNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE QgemmU8X8KernelAvx2Common.inc -INCLUDE AssembleAvxVnni.inc - .list - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single row of the -; output block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 16). -; -; Vec2Reg - Supplies the low block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from matrix B. -; -; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 16). -; -; ymm2 - Supplies the broadcast value loaded from matrix A. -; - -MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg - -IF ColumnCount EQ 16 - VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1 -ELSE - VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm0 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - - vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO ColumnCount, RowCount - - LOCAL ComputeBlockBy1Loop - - mov rsi,r9 ; reload row length remaining - -ComputeBlockBy1Loop: - ComputeBlock ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,64 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop - - ENDM - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmU8S8CopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmU8S8CopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PackedCountK (r9) - Supplies the number of packed columns from matrix A and -; the number of packed rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the -; zero point offset of matrix B. These values are accumulated into every -; row of matrix C. -; -; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. These values are accumulated into -; every column of matrix C. -; -; DepthValue - Supplies the value CountK multiplied by the zero point offset -; of matrix A multplied by the zero point offset of matrix B. This value is -; accumulated into every element of matrix C. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (GemmU8X8KernelFrame.SavedR13) - save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 - save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7 - save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8 - save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9 - save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10 - save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11 - save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12 - save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13 - save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 - save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 - - END_PROLOGUE - - mov rdi,rcx - mov rbp,GemmU8X8KernelFrame.CountN[rsp] - mov rax,GemmU8X8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - shl r9,2 ; convert to row length - movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] - mov r11,GemmU8X8KernelFrame.CountM[rsp] - mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] - mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] - -; -; Process CountM rows of the matrices. -; - - cmp r11,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM6: - mov r11d,6 ; return 6 rows handled - ProcessCountM 6, Fallthrough - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - mov eax,r11d - vzeroupper - movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp] - movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp] - movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp] - movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp] - movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp] - movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp] - movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp] - movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp] - movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] - movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmU8X8KernelFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM5: - ProcessCountM 5 - - NESTED_END MlasGemmU8S8KernelAvxVnni, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm index d30823756f..30c97fd36f 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm @@ -19,9 +19,10 @@ .xlist INCLUDE mlasi.inc -INCLUDE QgemmU8X8KernelAvx2Common.inc .list + EXTERN MlasMaskMoveTableAvx:NEAR + ; ; Stack frame layout for the U8U8 CopyPackA routine. ; @@ -136,9 +137,9 @@ GemmU8U8CopyPackBFrame ENDS and eax,15 ; isolate unaligned count inc eax shr eax,1 ; align unaligned count to pair count - mov DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp],eax - vpbroadcastd ymm9,DWORD PTR GemmU8U8CopyPackAFrame.CountK[rsp] - vpcmpgtd ymm9,ymm9,YMMWORD PTR [MlasMaskMoveAvx] + neg rax + lea rbx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm9,YMMWORD PTR [rbx+rax*4] ; ; Zero initialize the padded stack buffers. @@ -663,305 +664,4 @@ StoreColumnSumBufferNUnaligned: NESTED_END MlasGemmU8U8CopyPackBAvx2, _TEXT -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single row of the -; output block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 16). -; -; Vec2Reg - Supplies the low block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from matrix B. -; -; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 16). -; -; ymm2 - Supplies the broadcast value loaded from matrix A. -; - -MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg - - vpmaddwd ymm3,ymm2,ymm0 -IF ColumnCount EQ 16 - vpaddd Vec1Reg,Vec1Reg,ymm3 - vpmaddwd ymm2,ymm2,ymm1 - vpaddd Vec2Reg,Vec2Reg,ymm2 -ELSE - vpaddd Vec2Reg,Vec2Reg,ymm3 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - - vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO ColumnCount, RowCount - - LOCAL ComputeBlockBy2Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL ComputeBlockLoopExit - - mov rsi,r9 ; reload row length remaining - -IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) - sub rsi,2*4 - jb ProcessRemainingBlocks - -ComputeBlockBy2Loop: - ComputeBlock ColumnCount, RowCount, 0, 0 - ComputeBlock ColumnCount, RowCount, 32, 4 - add rcx,2*4 ; advance matrix A by 2 pairs -IF RowCount GT 3 - add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs -ENDIF - add rdx,2*32 ; advance matrix B - sub rsi,2*4 - jae ComputeBlockBy2Loop - -ProcessRemainingBlocks: - add rsi,2*4 ; correct for over-subtract above - jz ComputeBlockLoopExit - ComputeBlock ColumnCount, RowCount, 0, 0 - add rdx,32 ; advance matrix B -ELSE -ComputeBlockBy1Loop: - ComputeBlock ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 pair -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 pair -ENDIF - add rdx,32 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop -ENDIF - -ComputeBlockLoopExit: - - ENDM - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmU8U8CopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmU8U8CopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PackedCountK (r9) - Supplies the number of packed columns from matrix A and -; the number of packed rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the -; zero point offset of matrix B. These values are accumulated into every -; row of matrix C. -; -; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. These values are accumulated into -; every column of matrix C. -; -; DepthValue - Supplies the value CountK multiplied by the zero point offset -; of matrix A multplied by the zero point offset of matrix B. This value is -; accumulated into every element of matrix C. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemmU8U8KernelAvx2, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (GemmU8X8KernelFrame.SavedR13) - save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 - save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7 - save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8 - save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9 - save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10 - save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11 - save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12 - save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13 - save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 - save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 - - END_PROLOGUE - - mov rdi,rcx - mov rbp,GemmU8X8KernelFrame.CountN[rsp] - mov rax,GemmU8X8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - shl r9,2 ; convert to row length - movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] - mov r11,GemmU8X8KernelFrame.CountM[rsp] - mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] - mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] - -; -; Process CountM rows of the matrices. -; - - cmp r11,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM6: - mov r11d,6 ; return 6 rows handled - ProcessCountM 6, Fallthrough - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - mov eax,r11d - vzeroupper - movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp] - movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp] - movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp] - movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp] - movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp] - movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp] - movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp] - movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp] - movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] - movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmU8X8KernelFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM5: - ProcessCountM 5 - - NESTED_END MlasGemmU8U8KernelAvx2, _TEXT - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Core.asm deleted file mode 100644 index 554aa1c981..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Core.asm +++ /dev/null @@ -1,172 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8U8KernelAvx512Core.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX512 core instructions (BW/DQ/VL). -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE QgemmU8X8KernelAvx512Common.inc - .list - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single cell of the -; output block. -; -; Arguments: -; -; AccumReg - Supplies the register to accumulate into. -; -; Mult1Reg - Supplies the first multiplication operand register. -; -; Mult2Reg - Supplies the second multiplication operand register. -; -; Implicit Arguments: -; -; zmm4 - Supplies a scratch register for intermediate results. -; - -MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg - - vpmaddwd zmm4,Mult1Reg,Mult2Reg - vpaddd AccumReg,AccumReg,zmm4 - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - -IF ColumnCount GE 48 - vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset] - vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset] - vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset] -ELSEIF ColumnCount GE 32 - vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset] - vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset] -ELSE - vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset] -ENDIF - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO ColumnCount, RowCount - - LOCAL ComputeBlockBy1Loop - - mov rsi,r9 ; reload row length remaining - -ComputeBlockBy1Loop: - ComputeBlock ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 pair -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 pair -ENDIF - add rdx,32 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop - - ENDM - -; -; Generate the GEMM kernel. -; - -GemmU8X8KernelAvx512Function U8U8, Avx512Core - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm new file mode 100644 index 0000000000..4b345d07fe --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm @@ -0,0 +1,940 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8X8KernelAvx2.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX2 and AVX VNNI instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE AssembleAvxVnni.inc + .list + + EXTERN MlasMaskMoveTableAvx:NEAR + +; +; Stack frame layout for the U8X8 kernel. +; + +GemmU8X8KernelFrame STRUCT + + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + SavedXmm9 OWORD ? + SavedXmm10 OWORD ? + SavedXmm11 OWORD ? + SavedXmm12 OWORD ? + SavedXmm13 OWORD ? + SavedXmm14 OWORD ? + SavedXmm15 OWORD ? + Padding QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountM QWORD ? + CountN QWORD ? + ldc QWORD ? + RowSumBuffer QWORD ? + ColumnSumBuffer QWORD ? + ZeroPointB QWORD ? + ZeroMode QWORD ? + +GemmU8X8KernelFrame ENDS + +; +; Macro Description: +; +; This macro generates code to multiply and accumulator a single row of the +; output block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; Vec1Reg - Supplies the high block accumulator register (when ColumnCount +; is 16). +; +; Vec2Reg - Supplies the low block accumulator register. +; +; Implicit Arguments: +; +; ymm0 - Supplies the first vector loaded from matrix B. +; +; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount +; is 16). +; +; ymm2 - Supplies the broadcast value loaded from matrix A. +; +; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. +; + +MultiplyAccumulateRowU8S8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg + + vpmaddubsw ymm3,ymm2,ymm0 + vpmaddwd ymm3,ymm3,ymm12 +IF ColumnCount EQ 16 + vpaddd Vec1Reg,Vec1Reg,ymm3 + vpmaddubsw ymm2,ymm2,ymm1 + vpmaddwd ymm2,ymm2,ymm12 + vpaddd Vec2Reg,Vec2Reg,ymm2 +ELSE + vpaddd Vec2Reg,Vec2Reg,ymm3 +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm11 - Supplies the block accumulators. +; +; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. +; + +ComputeBlockU8S8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + +IF RowCount EQ 1 + vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset] + vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset] + vpmaddwd ymm3,ymm3,ymm12 +IF ColumnCount EQ 16 + vpaddd ymm4,ymm4,ymm3 + vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32] + vpmaddwd ymm2,ymm2,ymm12 + vpaddd ymm5,ymm5,ymm2 +ELSE + vpaddd ymm5,ymm5,ymm3 +ENDIF +ELSE + vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] + EmitIfCountGE ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulator a single row of the +; output block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; Vec1Reg - Supplies the high block accumulator register (when ColumnCount +; is 16). +; +; Vec2Reg - Supplies the low block accumulator register. +; +; Implicit Arguments: +; +; ymm0 - Supplies the first vector loaded from matrix B. +; +; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount +; is 16). +; +; ymm2 - Supplies the broadcast value loaded from matrix A. +; + +MultiplyAccumulateRowU8S8AvxVnni MACRO ColumnCount, Vec1Reg, Vec2Reg + +IF ColumnCount EQ 16 + VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0 + VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1 +ELSE + VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm0 +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm15 - Supplies the block accumulators. +; + +ComputeBlockU8S8AvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + + vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] + EmitIfCountGE ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + + ENDM + +; +; Macro Description: +; +; This macro generates code to execute the block compute macro multiple times +; and advancing the matrix A and matrix B data pointers. +; +; Arguments: +; +; Isa - Supplies the instruction set architecture string. +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm11 - Supplies the block accumulators. +; + +ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount + + LOCAL ComputeBlockBy1Loop + + mov rsi,r9 ; reload row length remaining + +ComputeBlockBy1Loop: + ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 quad +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 quad +ENDIF + add rdx,64 ; advance matrix B + sub rsi,4 + jnz ComputeBlockBy1Loop + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulator a single row of the +; output block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; Vec1Reg - Supplies the high block accumulator register (when ColumnCount +; is 16). +; +; Vec2Reg - Supplies the low block accumulator register. +; +; Implicit Arguments: +; +; ymm0 - Supplies the first vector loaded from matrix B. +; +; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount +; is 16). +; +; ymm2 - Supplies the broadcast value loaded from matrix A. +; + +MultiplyAccumulateRowU8U8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg + + vpmaddwd ymm3,ymm2,ymm0 +IF ColumnCount EQ 16 + vpaddd Vec1Reg,Vec1Reg,ymm3 + vpmaddwd ymm2,ymm2,ymm1 + vpaddd Vec2Reg,Vec2Reg,ymm2 +ELSE + vpaddd Vec2Reg,Vec2Reg,ymm3 +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm15 - Supplies the block accumulators. +; + +ComputeBlockU8U8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + + vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset] + EmitIfCountGE ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + + ENDM + +; +; Macro Description: +; +; This macro generates code to execute the block compute macro multiple times +; and advancing the matrix A and matrix B data pointers. +; +; Arguments: +; +; Isa - Supplies the instruction set architecture string. +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm15 - Supplies the block accumulators. +; + +ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount + + LOCAL ComputeBlockBy2Loop + LOCAL ProcessRemainingBlocks + LOCAL ComputeBlockBy1Loop + LOCAL ExitComputeBlockLoop + + mov rsi,r9 ; reload row length remaining + +IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) + sub rsi,2*4 + jb ProcessRemainingBlocks + +ComputeBlockBy2Loop: + ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 + ComputeBlockU8U8&Isa& ColumnCount, RowCount, 32, 4 + add rcx,2*4 ; advance matrix A by 2 pairs +IF RowCount GT 3 + add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs +ENDIF + add rdx,2*32 ; advance matrix B + sub rsi,2*4 + jae ComputeBlockBy2Loop + +ProcessRemainingBlocks: + add rsi,2*4 ; correct for over-subtract above + jz ExitComputeBlockLoop + ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 + add rdx,32 ; advance matrix B +ELSE +ComputeBlockBy1Loop: + ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 pair +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 pair +ENDIF + add rdx,32 ; advance matrix B + sub rsi,4 + jnz ComputeBlockBy1Loop +ENDIF + +ExitComputeBlockLoop: + + ENDM + +; +; Macro Description: +; +; This macro generates code to produce an output block for a set of columns +; and rows. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r11 - Supplies the address of the row sum buffer. +; +; r12 - Supplies the address of the column sum buffer. +; +; r13 - Optionally supplies the address of the matrix B zero point buffer. +; +; ymm4-ymm15 - Supplies the block accumulators. +; + +ProduceOutputBlock MACRO ColumnCount, RowCount + + LOCAL SkipScaleByZeroPointB + LOCAL AccumulatorsInitialized + LOCAL ProduceWithU8S8AvxVnni + LOCAL ProduceWithU8U8Avx2 + LOCAL ExitProduceOutputBlock + +; +; Initialize the accumulators with the row and column sums. +; + + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, +IF ColumnCount EQ 16 + vmovdqu ymm0,YMMWORD PTR [r12] + vmovdqu ymm1,YMMWORD PTR [r12+32] + add r12,16*4 ; advance ColumnSumBuffer by 16 columns +ELSE + vmovdqu ymm1,YMMWORD PTR [r12] +ENDIF + test r13,r13 ; per column zero points? + jz SkipScaleByZeroPointB +IF ColumnCount EQ 16 + vmovdqu ymm2,YMMWORD PTR [r13] + vmovdqu ymm3,YMMWORD PTR [r13+32] + add r13,16*4 ; advance ZeroPointB by 16 columns +ELSE + vmovdqu ymm3,YMMWORD PTR [r13] +ENDIF + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + jmp AccumulatorsInitialized + +SkipScaleByZeroPointB: + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + +AccumulatorsInitialized: + +; +; Iterate over the length of a matrix A row to produce the output accumulators. +; + +IF RowCount GT 3 + lea rbx,[r9*2+r9] + add rbx,rcx ; compute matrix A plus 3 rows +ENDIF + cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 + jg ProduceWithU8U8Avx2 +IF RowCount LE 4 + jl ProduceWithU8S8AvxVnni + ComputeBlockLoopU8S8 Avx2, ColumnCount, RowCount + jmp ExitProduceOutputBlock +ENDIF + +ProduceWithU8S8AvxVnni: + ComputeBlockLoopU8S8 AvxVnni, ColumnCount, RowCount + jmp ExitProduceOutputBlock + +ProduceWithU8U8Avx2: + ComputeBlockLoopU8U8 Avx2, ColumnCount, RowCount + +ExitProduceOutputBlock: +IF RowCount GT 3 + lea rbx,[rax*2+rax] + add rbx,r8 ; compute matrix C plus 3 rows +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to compute matrix multiplication for a fixed set +; of rows. +; +; Arguments: +; +; RowCount - Supplies the number of rows to process. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address of matrix A. +; +; rdx - Supplies the address of matrix B. +; +; r8 - Supplies the address of matrix C. +; +; rdi - Supplies the address of matrix A. +; +; rbp - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r10b - Supplies the zero mode flag. +; +; r11 - Supplies the address of the row sum buffer. +; +; r12 - Supplies the address of the column sum buffer. +; +; r13 - Optionally supplies the address of the matrix B zero point buffer. +; + +ProcessCountM MACRO RowCount, Fallthrough + + LOCAL ProcessNextColumnLoop16xN + LOCAL SkipAccumulateOutput16xNBlock + LOCAL OutputMasked16xNBlock + LOCAL ExitProcessCountM + LOCAL ProcessRemainingCountN + LOCAL SkipAccumulateOutput8xNBlock + LOCAL SkipAccumulateOutputMasked16xNBlock + LOCAL OutputMasked8xNBlock + LOCAL SkipAccumulateOutputMasked8xNBlock + + cmp rbp,8 + jbe ProcessRemainingCountN + +ProcessNextColumnLoop16xN: + ProduceOutputBlock 16, RowCount + sub rbp,16 + jb OutputMasked16xNBlock + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput16xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput16xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,8 + ja ProcessNextColumnLoop16xN + test rbp,rbp + jnz ProcessRemainingCountN + +ExitProcessCountM: + mov eax,RowCount + jmp ExitKernel + +ProcessRemainingCountN: + ProduceOutputBlock 8, RowCount + cmp rbp,8 + jb OutputMasked8xNBlock + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput8xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput8xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + jmp ExitProcessCountM + +OutputMasked16xNBlock: + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutputMasked16xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutputMasked16xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,8*4 ; advance matrix C by 8 columns +IF RowCount GT 3 + add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns +ENDIF + add rbp,8 ; correct for over-subtract above + +OutputMasked8xNBlock: + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutputMasked8xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutputMasked8xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + jmp ExitProcessCountM + + ENDM + +; +; Reduce code size for the various types of kernels by sharing the outer logic +; and switching on the selector codes (using sign bit to discriminate). +; + + LEAF_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT + + mov eax,-1 + jmp MlasGemmU8X8KernelAvx2 + + LEAF_END MlasGemmU8S8KernelAvxVnni, _TEXT + + LEAF_ENTRY MlasGemmU8U8KernelAvx2, _TEXT + + mov eax,1 + jmp MlasGemmU8X8KernelAvx2 + + LEAF_END MlasGemmU8U8KernelAvx2, _TEXT + + LEAF_ENTRY MlasGemmU8S8KernelAvx2, _TEXT + + xor eax,eax + jmp MlasGemmU8X8KernelAvx2 + + LEAF_END MlasGemmU8S8KernelAvx2, _TEXT + +;++ +; +; Routine Description: +; +; This routine is an inner kernel to compute matrix multiplication for a +; set of rows. +; +; Arguments: +; +; A (rcx) - Supplies the address of matrix A. The matrix data has been packed +; using MlasGemmU8X8CopyPackAAvx2. +; +; B (rdx) - Supplies the address of matrix B. The matrix data has been packed +; using MlasGemmU8X8CopyPackBAvx2. +; +; C (r8) - Supplies the address of matrix C. +; +; PackedCountK (r9) - Supplies the number of packed columns from matrix A and +; the number of packed rows from matrix B to iterate over. +; +; CountM - Supplies the maximum number of rows that can be processed for +; matrix A and matrix C. The actual number of rows handled for this +; invocation depends on the kernel implementation. +; +; CountN - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; ldc - Supplies the first dimension of matrix C. +; +; RowSumBuffer - Supplies the sum of each row from matrix A. These values have +; been pre-scaled by the zero point offset of matrix B if the offset is +; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be +; scaled by the per-column zero point offsets of matrix B. These values are +; accumulated into every row of matrix C. +; +; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied +; by the zero point offset of matrix A. These values are accumulated into +; every column of matrix C. +; +; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix +; B, else nullptr if the matrix B is using per-tensor quantization. +; +; ZeroMode - Supplies true if the output matrix must be zero initialized, +; else false if the output matrix is accumulated into. +; +; Return Value: +; +; Returns the number of rows handled. +; +;-- + + NESTED_ENTRY MlasGemmU8X8KernelAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + alloc_stack (GemmU8X8KernelFrame.SavedR13) + save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 + save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7 + save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8 + save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9 + save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10 + save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11 + save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12 + save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13 + save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 + save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 + + END_PROLOGUE + + mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax + mov rdi,rcx + mov rbx,GemmU8X8KernelFrame.CountM[rsp] + mov rbp,GemmU8X8KernelFrame.CountN[rsp] + mov rax,GemmU8X8KernelFrame.ldc[rsp] + shl rax,2 ; convert ldc to bytes + shl r9,2 ; convert to row length + movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] + mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp] + mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] + mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp] + vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] + vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] + cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 + je CheckCountM4OrMore ; U8S8 AVX2 kernel requires extra registers + +; +; Process CountM rows of the matrices. +; + +CheckCountM6OrMore: + cmp rbx,5 + ja ProcessCountM6 + je ProcessCountM5 + +CheckCountM4OrMore: + cmp rbx,3 + ja ProcessCountM4 + je ProcessCountM3 + cmp rbx,1 + je ProcessCountM1 + +ProcessCountM2: + ProcessCountM 2 + +ProcessCountM4: + ProcessCountM 4 + +ProcessCountM6: + ProcessCountM 6 + +; +; Restore non-volatile registers and return. +; + +ExitKernel: + vzeroupper + movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp] + movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp] + movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp] + movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp] + movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp] + movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp] + movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp] + movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp] + movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] + movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] + add rsp,(GemmU8X8KernelFrame.SavedR13) + + BEGIN_EPILOGUE + + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + +ProcessCountM1: + ProcessCountM 1 + +ProcessCountM3: + ProcessCountM 3 + +ProcessCountM5: + ProcessCountM 5 + + NESTED_END MlasGemmU8X8KernelAvx2, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc deleted file mode 100644 index 9d45b4fa25..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc +++ /dev/null @@ -1,302 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8X8KernelAvx2Common.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the quantized -; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels. -; -;-- - - EXTERN MlasMaskMoveAvx:NEAR - -; -; Stack frame layout for the U8S8 and U8U8 kernels. -; - -GemmU8X8KernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - ldc QWORD ? - RowSumBuffer QWORD ? - ColumnSumBuffer QWORD ? - DepthValue QWORD ? - ZeroMode QWORD ? - -GemmU8X8KernelFrame ENDS - -; -; Macro Description: -; -; This macro generates code to produce an output block for a set of columns -; and rows. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r12 - Supplies the address of the row sum buffer. -; -; r13 - Supplies the address of the column sum buffer. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ProduceOutputBlock MACRO ColumnCount, RowCount - -; -; Initialize the accumulators with the sum of the global depth value constant, -; the column sums, and the row sums. -; - - vpbroadcastd ymm1,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp] -IF ColumnCount EQ 16 - vpaddd ymm0,ymm1,YMMWORD PTR [r13] - vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] - add r13,16*4 ; advance ColumnSumBuffer by 16 columns -ELSE - vpaddd ymm1,ymm1,YMMWORD PTR [r13] -ENDIF - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - -; -; Iterate over the length of a matrix A row to produce the output accumulators. -; - -IF RowCount GT 3 - lea rbx,[r9*2+r9] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - ComputeBlockLoop ColumnCount, RowCount -IF RowCount GT 3 - lea rbx,[rax*2+rax] - add rbx,r8 ; compute matrix C plus 3 rows -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Fallthrough - Supplies a non-blank value if the macro may fall through to -; the ExitKernel label. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; r8 - Supplies the address of matrix C. -; -; rdi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r10b - Supplies the zero mode flag. -; -; r12 - Supplies the address of the row sum buffer. -; -; r13 - Supplies the address of the column sum buffer. -; - -ProcessCountM MACRO RowCount, Fallthrough - - LOCAL ProcessNextColumnLoop16xN - LOCAL SkipAccumulateOutput16xNBlock - LOCAL OutputMasked16xNBlock - LOCAL ProcessRemainingCountN - LOCAL SkipAccumulateOutput8xNBlock - LOCAL SkipAccumulateOutputMasked16xNBlock - LOCAL OutputMasked8xNBlock - LOCAL SkipAccumulateOutputMasked8xNBlock - - cmp rbp,8 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop16xN: - ProduceOutputBlock 16, RowCount - sub rbp,16 - jb OutputMasked16xNBlock - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput16xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput16xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rdi ; reload matrix A - cmp rbp,8 - ja ProcessNextColumnLoop16xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - ProduceOutputBlock 8, RowCount - cmp rbp,8 - jb OutputMasked8xNBlock - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput8xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput8xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp ExitKernel - -OutputMasked16xNBlock: - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutputMasked16xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutputMasked16xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,8*4 ; advance matrix C by 8 columns -IF RowCount GT 3 - add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns -ENDIF - add rbp,8 ; correct for over-subtract above - -OutputMasked8xNBlock: - mov DWORD PTR GemmU8X8KernelFrame.CountN[rsp],ebp - vpbroadcastd ymm0,DWORD PTR GemmU8X8KernelFrame.CountN[rsp] - vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx] - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutputMasked8xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutputMasked8xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, -IFB - jmp ExitKernel -ENDIF - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc deleted file mode 100644 index ae02a70a05..0000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc +++ /dev/null @@ -1,438 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8X8KernelAvx512Common.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the quantized -; integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and -; AVX512VNNI kernels. -; -;-- - -; -; Stack frame layout for the U8S8 and U8U8 kernels. -; - -GemmU8X8KernelFrame STRUCT - - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - SavedR14 QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - ldc QWORD ? - RowSumBuffer QWORD ? - ColumnSumBuffer QWORD ? - DepthValue QWORD ? - ZeroMode QWORD ? - -GemmU8X8KernelFrame ENDS - -; -; Macro Description: -; -; This macro generates code to produce an output block for a set of columns -; and rows. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r12 - Supplies the address of the row sum buffer. -; -; r13 - Supplies the address of the column sum buffer. -; - -ProduceOutputBlock MACRO ColumnCount, RowCount - -; -; Initialize the accumulators with the sum of the global depth value constant, -; the column sums, and the row sums. -; - - vpbroadcastd zmm3,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp] -IF ColumnCount GE 32 -IF ColumnCount GE 48 - vpaddd zmm2,zmm3,ZMMWORD PTR [r13] - vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64] - vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128] -ELSE - vpaddd zmm1,zmm3,ZMMWORD PTR [r13] - vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64] -ENDIF - add_immed r13,ColumnCount*4 ; advance ColumnSumBuffer by N columns -ELSE - vpaddd zmm0,zmm3,ZMMWORD PTR [r13] -ENDIF - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - -; -; Iterate over the length of a matrix A row to produce the output accumulators. -; - -IF RowCount GT 3 - lea rbx,[r9*2+r9] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - ComputeBlockLoop ColumnCount, RowCount -IF RowCount GT 3 - lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows - add rbx,rax -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; r8 - Supplies the address of matrix C. -; -; rdi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r10b - Supplies the zero mode flag. -; -; r12 - Supplies the address of the row sum buffer. -; -; r13 - Supplies the address of the column sum buffer. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; - -ProcessCountM MACRO RowCount - - LOCAL ProcessNextColumnLoop32xN - LOCAL Output32xNBlock - LOCAL SkipAccumulateOutput32xNBlock - LOCAL Output16xNBlock - LOCAL Output16xNBlockWithMask - LOCAL SkipAccumulateOutput16xNBlockWithMask - LOCAL ProcessRemainingCountN - LOCAL ProcessNextColumnLoop48xN - LOCAL SkipAccumulateOutput48xNBlock - - cmp rbp,32 - ja ProcessNextColumnLoop48xN - cmp rbp,16 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop32xN: - ProduceOutputBlock 32, RowCount - add rdx,r14 ; advance matrix B by packed block stride - -Output32xNBlock: - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput32xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput32xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns -IF RowCount GT 3 - add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns -ENDIF - sub rbp,16 - -Output16xNBlock: - sub rbp,16 - jae Output16xNBlockWithMask - lea ecx,[ebp+16] ; correct for over-subtract above - mov esi,1 - shl esi,cl - dec esi - kmovw k1,esi ; update mask for remaining columns - xor ebp,ebp ; no more columns remaining - -Output16xNBlockWithMask: - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput16xNBlockWithMask - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput16xNBlockWithMask: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rdi ; reload matrix A - cmp rbp,32 - ja ProcessNextColumnLoop48xN - cmp rbp,16 - ja ProcessNextColumnLoop32xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - ProduceOutputBlock 16, RowCount - jmp Output16xNBlock - -ProcessNextColumnLoop48xN: - ProduceOutputBlock 48, RowCount - lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput48xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput48xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns -IF RowCount GT 3 - add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns -ENDIF - sub rbp,16 - jmp Output32xNBlock - - ENDM - -; -; Macro Description: -; -; This macro generates the common AVX512 code for the inner kernel to compute -; matrix multiplication. -; -; Arguments: -; -; Type - Supplies the kernel type string for function tags. -; -; Isa - Supplies the instruction set architecture string for function tags. -; - -GemmU8X8KernelAvx512Function MACRO Type, Isa - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmU8X8CopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmU8X8CopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PackedCountK (r9) - Supplies the number of packed columns from matrix A and -; the number of packed rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the -; zero point offset of matrix B. These values are accumulated into every -; row of matrix C. -; -; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. These values are accumulated into -; every column of matrix C. -; -; DepthValue - Supplies the value CountK multiplied by the zero point offset -; of matrixA multplied by the zero point offset of matrix B. This value is -; accumulated into every element of matrix C. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemm&Type&Kernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - push_reg r14 - alloc_stack (GemmU8X8KernelFrame.SavedR14) - save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 - save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 - - END_PROLOGUE - - mov rdi,rcx - mov rbp,GemmU8X8KernelFrame.CountN[rsp] - mov rax,GemmU8X8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - shl r9,2 ; convert to row length - movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] - mov r11,GemmU8X8KernelFrame.CountM[rsp] - mov r12,GemmU8X8KernelFrame.RowSumBuffer[rsp] - mov r13,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] - mov esi,-1 - kmovw k1,esi ; update mask to write all columns -IFIDNI , -IFIDNI , - neg esi - vpbroadcastw zmm5,esi ; generate 512-bit word vector [0x0001] -ENDIF - mov r14,r9 - shl r14,4 ; compute matrix B packed stride -ELSE - lea r14,[r9*8] ; compute matrix B packed stride -ENDIF - -; -; Process CountM rows of the matrices. -; - - cmp r11,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM6: - mov r11d,6 ; return 6 rows handled - ProcessCountM 6 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - mov eax,r11d - vzeroupper - movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] - movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmU8X8KernelFrame.SavedR14) - - BEGIN_EPILOGUE - - pop r14 - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM5: - ProcessCountM 5 - - NESTED_END MlasGemm&Type&Kernel&Isa&, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm new file mode 100644 index 0000000000..bcd3f52b7e --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm @@ -0,0 +1,764 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8X8KernelAvx512Core.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE AssembleAvx512Vnni.inc + .list + +; +; Stack frame layout for the U8X8 kernel. +; + +GemmU8X8KernelFrame STRUCT + + SavedXmm13 OWORD ? + SavedXmm14 OWORD ? + SavedXmm15 OWORD ? + SavedR14 QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountM QWORD ? + CountN QWORD ? + ldc QWORD ? + RowSumBuffer QWORD ? + ColumnSumBuffer QWORD ? + ZeroPointB QWORD ? + ZeroMode QWORD ? + +GemmU8X8KernelFrame ENDS + +; +; Macro Description: +; +; This macro generates code to load packed data from matrix B. +; +; Arguments: +; +; VecReg - Supplies the register to load the data into. +; +; AddressOperand - Supplies the address operand. +; + +LoadPackedMatrixBU8S8 MACRO VecReg, AddressOperand + + vmovdqu32 VecReg,ZMMWORD PTR AddressOperand + + ENDM + +LoadPackedMatrixBU8U8 MACRO VecReg, AddressOperand + + vpmovzxbw VecReg,YMMWORD PTR AddressOperand + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulator a single cell of the +; output block. +; +; Arguments: +; +; AccumReg - Supplies the register to accumulate into. +; +; Mult1Reg - Supplies the first multiplication operand register. +; +; Mult2Reg - Supplies the second multiplication operand register. +; +; Implicit Arguments: +; +; zmm4 - Supplies a scratch register for intermediate results. +; +; zmm13 - Supplies a 512-bit with the broadcasted word value 0x0001. +; + +MultiplyAccumulateCellU8S8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg + + vpmaddubsw zmm4,Mult1Reg,Mult2Reg + vpmaddwd zmm4,zmm4,zmm13 + vpaddd AccumReg,AccumReg,zmm4 + + ENDM + +MultiplyAccumulateCellU8S8Avx512Vnni MACRO AccumReg, Mult1Reg, Mult2Reg + + VpdpbusdsZmmZmmZmm AccumReg,Mult1Reg,Mult2Reg + + ENDM + +MultiplyAccumulateCellU8U8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg + + vpmaddwd zmm4,Mult1Reg,Mult2Reg + vpaddd AccumReg,AccumReg,zmm4 + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; Type - Supplies the type of kernel to generate (U8S8 or U8U8). +; +; Isa - Supplies the instruction set architecture string. +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r14 - Supplies the stride in bytes of between packed blocks of matrix B. +; +; zmm13 - Supplies a 512-bit with the broadcasted word value 0x0001. +; +; zmm14-zmm31 - Supplies the block accumulators. +; + +ComputeBlock MACRO Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset + +IF ColumnCount GE 48 + LoadPackedMatrixB&Type& zmm0,[rdx+VectorOffset] + LoadPackedMatrixB&Type& zmm1,[rdx+r14+VectorOffset] + LoadPackedMatrixB&Type& zmm2,[rdx+r14*2+VectorOffset] +ELSEIF ColumnCount GE 32 + LoadPackedMatrixB&Type& zmm1,[rdx+VectorOffset] + LoadPackedMatrixB&Type& zmm2,[rdx+r14+VectorOffset] +ELSE + LoadPackedMatrixB&Type& zmm2,[rdx+VectorOffset] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + + ENDM + +; +; Macro Description: +; +; This macro generates code to execute the block compute macro multiple times +; and advancing the matrix A and matrix B data pointers. +; +; Arguments: +; +; Isa - Supplies the instruction set architecture string. +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r14 - Supplies the stride in bytes of between packed blocks of matrix B. +; +; zmm14-zmm31 - Supplies the block accumulators. +; + +ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount + + LOCAL ComputeBlockBy4Loop + LOCAL ProcessRemainingBlocks + LOCAL ComputeBlockBy1Loop + LOCAL ComputeBlockLoopExit + + mov rsi,r9 ; reload row length remaining + +IF ((RowCount AND 1) EQ 0) + sub rsi,4*4 + jb ProcessRemainingBlocks + +ComputeBlockBy4Loop: + ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0*64, 0 + ComputeBlock U8S8, Isa, ColumnCount, RowCount, 1*64, 4 + ComputeBlock U8S8, Isa, ColumnCount, RowCount, 2*64, 8 + ComputeBlock U8S8, Isa, ColumnCount, RowCount, 3*64, 12 + add rcx,4*4 ; advance matrix A by 1 quad +IF RowCount GT 3 + add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad +ENDIF + add rdx,4*64 ; advance matrix B + sub rsi,4*4 ; decrement quads remaining + jae ComputeBlockBy4Loop + +ProcessRemainingBlocks: + add rsi,4*4 ; correct for over-subtract above + jz ComputeBlockLoopExit +ENDIF + +ComputeBlockBy1Loop: + ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 quad +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 quad +ENDIF + add rdx,64 ; advance matrix B + sub rsi,4 ; decrement quads remaining + jnz ComputeBlockBy1Loop + +ComputeBlockLoopExit: + + ENDM + +ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount + + LOCAL ComputeBlockBy1Loop + + mov rsi,r9 ; reload row length remaining + +ComputeBlockBy1Loop: + ComputeBlock U8U8, Isa, ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 pair +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 pair +ENDIF + add rdx,32 ; advance matrix B + sub rsi,4 + jnz ComputeBlockBy1Loop + + ENDM + +; +; Macro Description: +; +; This macro generates code to produce an output block for a set of columns +; and rows. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r11 - Supplies the address of the row sum buffer. +; +; r12 - Supplies the address of the column sum buffer. +; + +ProduceOutputBlock MACRO ColumnCount, RowCount + + LOCAL SkipScaleByZeroPointB + LOCAL AccumulatorsInitialized + LOCAL ProduceWithU8S8Avx512Core + LOCAL ProduceWithU8U8Avx512Core + LOCAL ExitProduceOutputBlock + +; +; Initialize the accumulators with the row and column sums. +; + +IF ColumnCount GE 32 +IF ColumnCount GE 48 + vmovdqu32 zmm2,ZMMWORD PTR [r12] + vmovdqu32 zmm1,ZMMWORD PTR [r12+64] + vmovdqu32 zmm0,ZMMWORD PTR [r12+128] +ELSE + vmovdqu32 zmm1,ZMMWORD PTR [r12] + vmovdqu32 zmm0,ZMMWORD PTR [r12+64] +ENDIF + add_immed r12,ColumnCount*4 ; advance ColumnSumBuffer by N columns +ELSE + vmovdqu32 zmm0,ZMMWORD PTR [r12] +ENDIF + test r13,r13 ; per column zero points? + jz SkipScaleByZeroPointB +IF ColumnCount GE 32 +IF ColumnCount GE 48 + vmovdqu32 zmm5,ZMMWORD PTR [r13] + vmovdqu32 zmm4,ZMMWORD PTR [r13+64] + vmovdqu32 zmm3,ZMMWORD PTR [r13+128] +ELSE + vmovdqu32 zmm4,ZMMWORD PTR [r13] + vmovdqu32 zmm3,ZMMWORD PTR [r13+64] +ENDIF + add_immed r13,ColumnCount*4 ; advance ZeroPointB by N columns +ELSE + vmovdqu32 zmm3,ZMMWORD PTR [r13] +ENDIF + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + jmp AccumulatorsInitialized + +SkipScaleByZeroPointB: + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + +AccumulatorsInitialized: + +; +; Iterate over the length of a matrix A row to produce the output accumulators. +; + +IF RowCount GT 3 + lea rbx,[r9*2+r9] + add rbx,rcx ; compute matrix A plus 3 rows +ENDIF + cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 + je ProduceWithU8S8Avx512Core + jg ProduceWithU8U8Avx512Core + ComputeBlockLoopU8S8 Avx512Vnni, ColumnCount, RowCount + jmp ExitProduceOutputBlock + +ProduceWithU8U8Avx512Core: + ComputeBlockLoopU8U8 Avx512Core, ColumnCount, RowCount + jmp ExitProduceOutputBlock + +ProduceWithU8S8Avx512Core: + ComputeBlockLoopU8S8 Avx512Core, ColumnCount, RowCount + +ExitProduceOutputBlock: +IF RowCount GT 3 + lea rbx,[rax*2+rax] + add rbx,r8 ; compute matrix C plus 3 rows +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to compute matrix multiplication for a fixed set +; of rows. +; +; Arguments: +; +; RowCount - Supplies the number of rows to process. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address of matrix A. +; +; rdx - Supplies the address of matrix B. +; +; r8 - Supplies the address of matrix C. +; +; rdi - Supplies the address of matrix A. +; +; rbp - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r10b - Supplies the zero mode flag. +; +; r11 - Supplies the address of the row sum buffer. +; +; r12 - Supplies the address of the column sum buffer. +; +; r14 - Supplies the stride in bytes of between packed blocks of matrix B. +; + +ProcessCountM MACRO RowCount + + LOCAL ProcessNextColumnLoop32xN + LOCAL Output32xNBlock + LOCAL SkipAccumulateOutput32xNBlock + LOCAL Output16xNBlock + LOCAL Output16xNBlockWithMask + LOCAL SkipAccumulateOutput16xNBlockWithMask + LOCAL ProcessRemainingCountN + LOCAL ProcessNextColumnLoop48xN + LOCAL SkipAccumulateOutput48xNBlock + + cmp rbp,32 + ja ProcessNextColumnLoop48xN + cmp rbp,16 + jbe ProcessRemainingCountN + +ProcessNextColumnLoop32xN: + ProduceOutputBlock 32, RowCount + add rdx,r14 ; advance matrix B by packed block stride + +Output32xNBlock: + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput32xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput32xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns +IF RowCount GT 3 + add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns +ENDIF + sub rbp,16 + +Output16xNBlock: + sub rbp,16 + jae Output16xNBlockWithMask + lea ecx,[ebp+16] ; correct for over-subtract above + mov esi,1 + shl esi,cl + dec esi + kmovw k1,esi ; update mask for remaining columns + xor ebp,ebp ; no more columns remaining + +Output16xNBlockWithMask: + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput16xNBlockWithMask + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput16xNBlockWithMask: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,32 + ja ProcessNextColumnLoop48xN + cmp rbp,16 + ja ProcessNextColumnLoop32xN + test rbp,rbp + jnz ProcessRemainingCountN + mov eax,RowCount + jmp ExitKernel + +ProcessRemainingCountN: + ProduceOutputBlock 16, RowCount + jmp Output16xNBlock + +ProcessNextColumnLoop48xN: + ProduceOutputBlock 48, RowCount + lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput48xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput48xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns +IF RowCount GT 3 + add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns +ENDIF + sub rbp,16 + jmp Output32xNBlock + + ENDM + +; +; Reduce code size for the various types of kernels by sharing the outer logic +; and switching on the selector codes (using sign bit to discriminate). +; + + LEAF_ENTRY MlasGemmU8S8KernelAvx512Vnni, _TEXT + + mov eax,-1 + jmp MlasGemmU8X8KernelAvx512Core + + LEAF_END MlasGemmU8S8KernelAvx512Vnni, _TEXT + + LEAF_ENTRY MlasGemmU8U8KernelAvx512Core, _TEXT + + mov eax,1 + jmp MlasGemmU8X8KernelAvx512Core + + LEAF_END MlasGemmU8U8KernelAvx512Core, _TEXT + + LEAF_ENTRY MlasGemmU8S8KernelAvx512Core, _TEXT + + xor eax,eax + jmp MlasGemmU8X8KernelAvx512Core + + LEAF_END MlasGemmU8S8KernelAvx512Core, _TEXT + +;++ +; +; Routine Description: +; +; This routine is an inner kernel to compute matrix multiplication for a +; set of rows. +; +; Arguments: +; +; A (rcx) - Supplies the address of matrix A. The matrix data has been packed +; using MlasGemmU8X8CopyPackAAvx2. +; +; B (rdx) - Supplies the address of matrix B. The matrix data has been packed +; using MlasGemmU8X8CopyPackBAvx2. +; +; C (r8) - Supplies the address of matrix C. +; +; PackedCountK (r9) - Supplies the number of packed columns from matrix A and +; the number of packed rows from matrix B to iterate over. +; +; CountM - Supplies the maximum number of rows that can be processed for +; matrix A and matrix C. The actual number of rows handled for this +; invocation depends on the kernel implementation. +; +; CountN - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; ldc - Supplies the first dimension of matrix C. +; +; RowSumBuffer - Supplies the sum of each row from matrix A. These values have +; been pre-scaled by the zero point offset of matrix B if the offset is +; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be +; scaled by the per-column zero point offsets of matrix B. These values are +; accumulated into every row of matrix C. +; +; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied +; by the zero point offset of matrix A. These values are accumulated into +; every column of matrix C. +; +; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix +; B, else nullptr if the matrix B is using per-tensor quantization. +; +; ZeroMode - Supplies true if the output matrix must be zero initialized, +; else false if the output matrix is accumulated into. +; +; Return Value: +; +; Returns the number of rows handled. +; +;-- + + NESTED_ENTRY MlasGemmU8X8KernelAvx512Core, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + push_reg r14 + alloc_stack (GemmU8X8KernelFrame.SavedR14) + save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13 + save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 + save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 + + END_PROLOGUE + + mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax + mov rdi,rcx + mov rbx,GemmU8X8KernelFrame.CountM[rsp] + mov rbp,GemmU8X8KernelFrame.CountN[rsp] + mov rax,GemmU8X8KernelFrame.ldc[rsp] + shl rax,2 ; convert ldc to bytes + shl r9,2 ; convert to row length + movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] + mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp] + mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] + mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp] + mov esi,-1 + kmovw k1,esi ; update mask to write all columns + neg esi + vpbroadcastw zmm13,esi ; generate 512-bit word vector [0x0001] + lea rsi,[r9*8] ; compute matrix B packed stride (U8U8) + lea r14,[rsi*2] ; compute matrix B packed stride (U8S8) + cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 + cmovg r14,rsi ; select matrix B packed stride + +; +; Process CountM rows of the matrices. +; + + cmp rbx,5 + ja ProcessCountM6 + je ProcessCountM5 + cmp rbx,3 + ja ProcessCountM4 + je ProcessCountM3 + cmp rbx,1 + ja ProcessCountM2 + +ProcessCountM1: + ProcessCountM 1 + +ProcessCountM2: + ProcessCountM 2 + +ProcessCountM3: + ProcessCountM 3 + +ProcessCountM4: + ProcessCountM 4 + +ProcessCountM5: + ProcessCountM 5 + +ProcessCountM6: + ProcessCountM 6 + +; +; Restore non-volatile registers and return. +; + +ExitKernel: + vzeroupper + movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp] + movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] + movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] + add rsp,(GemmU8X8KernelFrame.SavedR14) + + BEGIN_EPILOGUE + + pop r14 + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + + NESTED_END MlasGemmU8X8KernelAvx512Core, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h b/onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h new file mode 100644 index 0000000000..2e0e4d3d8a --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h @@ -0,0 +1,46 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + AssembleDotProduct.h + +Abstract: + + This module contains macros to build Advanced SIMD dot product instructions + for toolchains that do not natively support this newer instruction set + extension. + + This implementation uses ARM v8.4 dot product instructions. + +--*/ + +/*++ + +Macro Description: + + This macro builds a UDOT instruction of the form: + + UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index] + +Arguments: + + DestReg - Specifies the destination register. + + Src1Reg - Specifies the first source register. + + Src2Reg - Specifies the second source register. + + Index - Specifies the element index of the second source register. + +--*/ + + MACRO + UdotByElement $DestReg, $Src1Reg, $Src2Reg, $Index + + DCD 0x6F80E000:OR:($DestReg):OR:($Src1Reg:SHL:5):OR:($Src2Reg:SHL:16):OR:(($Index:AND:2):SHL:10):OR:(($Index:AND:1):SHL:21) + + MEND diff --git a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm index d123bd67d4..8a335517c6 100644 --- a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm @@ -22,7 +22,7 @@ Abstract: // #define GemmU8X8KernelFrame_ColumnSumBuffer 0 -#define GemmU8X8KernelFrame_DepthValue 8 +#define GemmU8X8KernelFrame_ZeroPointB 8 #define GemmU8X8KernelFrame_ZeroMode 16 // @@ -75,10 +75,6 @@ Arguments: by the zero point offset of matrix A. These values are accumulated into every column of matrix C. - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrix A multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. - ZeroMode - Supplies true if the output matrix must be zero initialized, else false if the output matrix is accumulated into. @@ -91,18 +87,16 @@ Return Value: LEAF_ENTRY MlasGemmU8X8KernelNeon ldr x8,[sp,#GemmU8X8KernelFrame_ColumnSumBuffer] - ldr s27,[sp,#GemmU8X8KernelFrame_DepthValue] + ldr x9,[sp,#GemmU8X8KernelFrame_ZeroPointB] ldrb w13,[sp,#GemmU8X8KernelFrame_ZeroMode] - dup v27.4s,v27.s[0] mov x14,x0 - ld1 {v0.4s},[x7] + ld1 {v27.4s},[x7] mov x15,x3 - add v27.4s,v27.4s,v0.4s // broadcast add DepthValue dup v24.4s,v27.s[0] // broadcast row fixups cmp x4,#1 // CountM == 1? beq ProcessNextColumnLoopM1 dup v25.4s,v27.s[1] - cmp x4,#4 // CountM < 4 ? + cmp x4,#4 // CountM < 4? blo ProcessNextColumnLoopM2 dup v26.4s,v27.s[2] dup v27.4s,v27.s[3] @@ -118,6 +112,30 @@ ProcessNextColumnLoopM4 mov x3,x15 // reload PackedCountK ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 uxtl v0.8h,v0.8b + cbz x9,SkipScaleByZeroPointBM4 + ld1 {v28.4s},[x9],#16 // load ZeroPointB0 + ld1 {v29.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v24.4s,v28.4s + mul v17.4s,v24.4s,v29.4s + mul v18.4s,v25.4s,v28.4s + mul v19.4s,v25.4s,v29.4s + mul v20.4s,v26.4s,v28.4s + mul v21.4s,v26.4s,v29.4s + mul v22.4s,v27.4s,v28.4s + mul v23.4s,v27.4s,v29.4s + ld1 {v4.8b},[x0],#8 // load first packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + add v18.4s,v2.4s,v18.4s + add v19.4s,v3.4s,v19.4s + ld1 {v5.8b},[x0],#8 // load first packed A1 + add v20.4s,v2.4s,v20.4s + add v21.4s,v3.4s,v21.4s + add v22.4s,v2.4s,v22.4s + add v23.4s,v3.4s,v23.4s + b ComputeBlockLoopM4 + +SkipScaleByZeroPointBM4 ld1 {v4.8b},[x0],#8 // load first packed A0 add v16.4s,v2.4s,v24.4s add v17.4s,v3.4s,v24.4s @@ -329,6 +347,21 @@ ProcessNextColumnLoopM2 mov x3,x15 // reload PackedCountK ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 uxtl v0.8h,v0.8b + cbz x9,SkipScaleByZeroPointBM2 + ld1 {v28.4s},[x9],#16 // load ZeroPointB0 + ld1 {v29.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v24.4s,v28.4s + mul v17.4s,v24.4s,v29.4s + mul v18.4s,v25.4s,v28.4s + mul v19.4s,v25.4s,v29.4s + ld1 {v4.8b},[x0],#8 // load first packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + add v18.4s,v2.4s,v18.4s + add v19.4s,v3.4s,v19.4s + b ComputeBlockLoopM2 + +SkipScaleByZeroPointBM2 ld1 {v4.8b},[x0],#8 // load first packed A0 add v16.4s,v2.4s,v24.4s add v17.4s,v3.4s,v24.4s @@ -467,6 +500,17 @@ ProcessNextColumnLoopM1 mov x3,x15 // reload PackedCountK ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 uxtl v0.8h,v0.8b + cbz x9,SkipScaleByZeroPointBM1 + ld1 {v28.4s},[x9],#16 // load ZeroPointB0 + ld1 {v29.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v24.4s,v28.4s + mul v17.4s,v24.4s,v29.4s + ldr s4,[x0],#4 // load first packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + b ComputeBlockLoopM1 + +SkipScaleByZeroPointBM1 ldr s4,[x0],#4 // load first packed A0 add v16.4s,v2.4s,v24.4s add v17.4s,v3.4s,v24.4s diff --git a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm b/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm new file mode 100644 index 0000000000..6f1d979bbb --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm @@ -0,0 +1,587 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUdot.asm + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses ARM v8.4 dot product instructions. + +--*/ + +#include "kxarm64.h" +#include "AssembleDotProduct.h" + +// +// Stack frame layout for the U8X8 kernel. +// + +#define GemmU8XKernelFrame_SavedNeonRegisters (4 * 8) +#define GemmU8XKernelFrame_SavedRegisters GemmU8XKernelFrame_SavedNeonRegisters +#define GemmU8XKernelFrame_ColumnSumBuffer (0 + GemmU8XKernelFrame_SavedRegisters) +#define GemmU8XKernelFrame_ZeroPointB (8 + GemmU8XKernelFrame_SavedRegisters) +#define GemmU8XKernelFrame_ZeroMode (16 + GemmU8XKernelFrame_SavedRegisters) + + TEXTAREA + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8X8CopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8X8CopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + + ZeroMode - Supplies true if the output matrix must be zero initialized, else + false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + NESTED_ENTRY MlasGemmU8X8KernelUdot + + PROLOG_SAVE_REG_PAIR d8,d9,#-32! + PROLOG_SAVE_REG_PAIR d10,d11,#16 + ldr x8,[sp,#GemmU8XKernelFrame_ColumnSumBuffer] + ldr x9,[sp,#GemmU8XKernelFrame_ZeroPointB] + ldrb w13,[sp,#GemmU8XKernelFrame_ZeroMode] + mov x14,x0 + ld1 {v11.4s},[x7] + mov x15,x3 + dup v8.4s,v11.s[0] // broadcast row fixups + cmp x4,#1 // CountM == 1? + beq ProcessNextColumnLoopM1 + dup v9.4s,v11.s[1] + cmp x4,#4 // CountM < 4? + blo ProcessNextColumnLoopM2 + dup v10.4s,v11.s[2] + dup v11.4s,v11.s[3] + +// +// Process 4 rows of the matrices. +// + +ProcessNextColumnLoopM4 + ld1 {v0.16b},[x1],#16 // load packed B0 + mov x0,x14 // reload matrix A + ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] + mov x3,x15 // reload PackedCountK + ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] + cbz x9,SkipScaleByZeroPointBM4 + ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] + mul v16.4s,v30.4s,v8.4s + mul v18.4s,v30.4s,v9.4s + ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] + mul v20.4s,v30.4s,v10.4s + mul v22.4s,v30.4s,v11.4s + mul v17.4s,v31.4s,v8.4s + mul v19.4s,v31.4s,v9.4s + mul v21.4s,v31.4s,v10.4s + mul v23.4s,v31.4s,v11.4s + add v16.4s,v2.4s,v16.4s + add v18.4s,v2.4s,v18.4s + add v20.4s,v2.4s,v20.4s + add v22.4s,v2.4s,v22.4s + add v17.4s,v3.4s,v17.4s + add v19.4s,v3.4s,v19.4s + add v21.4s,v3.4s,v21.4s + add v23.4s,v3.4s,v23.4s + b ComputeBlockLoopStartM4 + +SkipScaleByZeroPointBM4 + add v16.4s,v2.4s,v8.4s + add v18.4s,v2.4s,v9.4s + add v20.4s,v2.4s,v10.4s + add v22.4s,v2.4s,v11.4s + add v17.4s,v3.4s,v8.4s + add v19.4s,v3.4s,v9.4s + add v21.4s,v3.4s,v10.4s + add v23.4s,v3.4s,v11.4s + +// +// The packing layout is setup to have a pair of four quad vectors from +// packed matrix A and a pair of eight quad vectors from packed matrix B. +// With this scheme, alternating loads from the packed matrices can be +// interleaved with the dot product instructions. +// +// One negative consequence of using four rows here is that the accumulator +// register tile is too small for processors with high out of order execution +// windows (such as the Apple M1). The dot product instructions for a given +// cell are too close to each other to avoid dependencies. To workaround this, +// the below loop uses a pair of accumulator registers that are then added +// together when the loop finishes. +// +// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for +// packed matrix A. At the time of this implementation, using a wider 128-bit +// load didn't affect performance for higher end cores. +// + +ComputeBlockLoopStartM4 + ldr d4,[x0],#32 // load packed A0.l + movi v24.4s,#0 + movi v25.4s,#0 + ldur d5,[x0,#-24] // load packed A0.h + movi v26.4s,#0 + movi v27.4s,#0 + ldur d6,[x0,#-16] // load packed A1.l + movi v28.4s,#0 + movi v29.4s,#0 + movi v30.4s,#0 + movi v31.4s,#0 + +ComputeBlockLoopM4 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 16, 0, 4, 0 + UdotByElement 18, 0, 4, 1 + ldur d7,[x0,#-8] // load packed A1.h + UdotByElement 20, 0, 5, 0 + UdotByElement 22, 0, 5, 1 + ld1 {v0.16b},[x1],#16 // load packed B0 + UdotByElement 17, 1, 4, 0 + UdotByElement 19, 1, 4, 1 + sub x3,x3,#1 + cbz x3,ComputeBlockLoopFinishM4 + ldr d4,[x0],#32 // load packed A0.l + UdotByElement 21, 1, 5, 0 + UdotByElement 23, 1, 5, 1 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 24, 0, 6, 0 + UdotByElement 26, 0, 6, 1 + ldur d5,[x0,#-24] // load packed A0.h + UdotByElement 28, 0, 7, 0 + UdotByElement 30, 0, 7, 1 + ld1 {v0.16b},[x1],#16 // load packed B0 + UdotByElement 25, 1, 6, 0 + UdotByElement 27, 1, 6, 1 + ldur d6,[x0,#-16] // load packed A1.l + UdotByElement 29, 1, 7, 0 + UdotByElement 31, 1, 7, 1 + b ComputeBlockLoopM4 + +ComputeBlockLoopFinishM4 + UdotByElement 21, 1, 5, 0 + UdotByElement 23, 1, 5, 1 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 24, 0, 6, 0 + UdotByElement 26, 0, 6, 1 + UdotByElement 28, 0, 7, 0 + UdotByElement 30, 0, 7, 1 + UdotByElement 25, 1, 6, 0 + UdotByElement 27, 1, 6, 1 + UdotByElement 29, 1, 7, 0 + UdotByElement 31, 1, 7, 1 + add x10,x2,x6,lsl #2 // compute output row 2 + add v16.4s,v16.4s,v24.4s // fold high results into low results + add v18.4s,v18.4s,v26.4s + add v20.4s,v20.4s,v28.4s + add v22.4s,v22.4s,v30.4s + add x11,x10,x6,lsl #2 // compute output row 3 + add v17.4s,v17.4s,v25.4s + add v19.4s,v19.4s,v27.4s + add v21.4s,v21.4s,v29.4s + add v23.4s,v23.4s,v31.4s + add x12,x11,x6,lsl #2 // compute output row 4 + subs x5,x5,#8 // adjust CountN remaining + blo StoreOutputPartialM4 + cbnz x13,SkipAccumulateOutputM4 + ldp q0,q1,[x2] + ldp q2,q3,[x10] + add v16.4s,v16.4s,v0.4s + add v17.4s,v17.4s,v1.4s + ldp q4,q5,[x11] + add v18.4s,v18.4s,v2.4s + add v19.4s,v19.4s,v3.4s + ldp q6,q7,[x12] + add v20.4s,v20.4s,v4.4s + add v21.4s,v21.4s,v5.4s + add v22.4s,v22.4s,v6.4s + add v23.4s,v23.4s,v7.4s + +SkipAccumulateOutputM4 + stp q16,q17,[x2],#32 + stp q18,q19,[x10] + stp q20,q21,[x11] + stp q22,q23,[x12] + cbnz x5,ProcessNextColumnLoopM4 + +ExitKernelM4 + mov x0,#4 // return number of rows handled + EPILOG_RESTORE_REG_PAIR d10,d11,#16 + EPILOG_RESTORE_REG_PAIR d8,d9,#32! + EPILOG_RETURN + +// +// Store the partial 1 to 7 columns either overwriting the output matrix or +// accumulating into the existing contents of the output matrix. +// + +StoreOutputPartialM4 + cbz x13,StoreOutputPartialAddModeM4 + +StoreOutputPartialZeroModeM4 + tbz x5,#2,StoreOutputPartial2ZeroModeM4 + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + st1 {v20.4s},[x11],#16 + mov v20.16b,v21.16b + st1 {v22.4s},[x12],#16 + mov v22.16b,v23.16b + +StoreOutputPartial2ZeroModeM4 + tbz x5,#1,StoreOutputPartial1ZeroModeM4 + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + st1 {v20.2s},[x11],#8 + dup v20.4s,v20.s[2] + st1 {v22.2s},[x12],#8 + dup v22.4s,v22.s[2] + +StoreOutputPartial1ZeroModeM4 + tbz x5,#0,ExitKernelM4 + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + st1 {v20.s}[0],[x11] + st1 {v22.s}[0],[x12] + b ExitKernelM4 + +StoreOutputPartialAddModeM4 + tbz x5,#2,StoreOutputPartial2AddModeM4 + ld1 {v0.4s},[x2] + ld1 {v1.4s},[x10] + ld1 {v2.4s},[x11] + ld1 {v3.4s},[x12] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + add v20.4s,v20.4s,v2.4s + add v22.4s,v22.4s,v3.4s + st1 {v20.4s},[x11],#16 + mov v20.16b,v21.16b + st1 {v22.4s},[x12],#16 + mov v22.16b,v23.16b + +StoreOutputPartial2AddModeM4 + tbz x5,#1,StoreOutputPartial1AddModeM4 + ld1 {v0.2s},[x2] + ld1 {v1.2s},[x10] + ld1 {v2.2s},[x11] + ld1 {v3.2s},[x12] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + add v20.4s,v20.4s,v2.4s + add v22.4s,v22.4s,v3.4s + st1 {v20.2s},[x11],#8 + dup v20.4s,v20.s[2] + st1 {v22.2s},[x12],#8 + dup v22.4s,v22.s[2] + +StoreOutputPartial1AddModeM4 + tbz x5,#0,ExitKernelM4 + ld1 {v0.s}[0],[x2] + ld1 {v1.s}[0],[x10] + add v16.4s,v16.4s,v0.4s + ld1 {v2.s}[0],[x11] + add v18.4s,v18.4s,v1.4s + ld1 {v3.s}[0],[x12] + add v20.4s,v20.4s,v2.4s + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + add v22.4s,v22.4s,v3.4s + st1 {v20.s}[0],[x11] + st1 {v22.s}[0],[x12] + b ExitKernelM4 + +// +// Process 2 rows of the matrices. +// + +ProcessNextColumnLoopM2 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + mov x0,x14 // reload matrix A + ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] + mov x3,x15 // reload PackedCountK + ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] + cbz x9,SkipScaleByZeroPointBM2 + ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] + ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] + mul v16.4s,v30.4s,v8.4s + mul v18.4s,v30.4s,v9.4s + mul v17.4s,v31.4s,v8.4s + mul v19.4s,v31.4s,v9.4s + ld1 {v4.16b},[x0],#16 // load packed A0 + add v16.4s,v2.4s,v16.4s + add v18.4s,v2.4s,v18.4s + add v17.4s,v3.4s,v17.4s + add v19.4s,v3.4s,v19.4s + b ComputeBlockLoopM2 + +SkipScaleByZeroPointBM2 + ld1 {v4.16b},[x0],#16 // load packed A0 + add v16.4s,v2.4s,v8.4s + add v18.4s,v2.4s,v9.4s + add v17.4s,v3.4s,v8.4s + add v19.4s,v3.4s,v9.4s + +ComputeBlockLoopM2 + UdotByElement 16, 0, 4, 0 + UdotByElement 17, 1, 4, 0 + UdotByElement 18, 0, 4, 1 + UdotByElement 19, 1, 4, 1 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 16, 0, 4, 2 + UdotByElement 17, 1, 4, 2 + UdotByElement 18, 0, 4, 3 + UdotByElement 19, 1, 4, 3 + sub x3,x3,#1 + cbz x3,ComputeBlockLoopFinishM2 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + ld1 {v4.16b},[x0],#16 // load packed A0 + b ComputeBlockLoopM2 + +ComputeBlockLoopFinishM2 + add x10,x2,x6,lsl #2 // compute output row 2 + subs x5,x5,#8 // adjust CountN remaining + blo StoreOutputPartialM2 + cbnz x13,SkipAccumulateOutputM2 + ldp q0,q1,[x2] + ldp q2,q3,[x10] + add v16.4s,v16.4s,v0.4s + add v17.4s,v17.4s,v1.4s + add v18.4s,v18.4s,v2.4s + add v19.4s,v19.4s,v3.4s + +SkipAccumulateOutputM2 + stp q16,q17,[x2],#32 + stp q18,q19,[x10] + cbnz x5,ProcessNextColumnLoopM2 + +ExitKernelM2 + mov x0,#2 // return number of rows handled + EPILOG_RESTORE_REG_PAIR d10,d11,#16 + EPILOG_RESTORE_REG_PAIR d8,d9,#32! + EPILOG_RETURN + +// +// Store the partial 1 to 7 columns either overwriting the output matrix or +// accumulating into the existing contents of the output matrix. +// + +StoreOutputPartialM2 + cbz x13,StoreOutputPartialAddModeM2 + +StoreOutputPartialZeroModeM2 + tbz x5,#2,StoreOutputPartial2ZeroModeM2 + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + +StoreOutputPartial2ZeroModeM2 + tbz x5,#1,StoreOutputPartial1ZeroModeM2 + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + +StoreOutputPartial1ZeroModeM2 + tbz x5,#0,ExitKernelM2 + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + b ExitKernelM2 + +StoreOutputPartialAddModeM2 + tbz x5,#2,StoreOutputPartial2AddModeM2 + ld1 {v0.4s},[x2] + ld1 {v1.4s},[x10] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + st1 {v18.4s},[x10],#16 + mov v18.16b,v19.16b + +StoreOutputPartial2AddModeM2 + tbz x5,#1,StoreOutputPartial1AddModeM2 + ld1 {v0.2s},[x2] + ld1 {v1.2s},[x10] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + st1 {v18.2s},[x10],#8 + dup v18.4s,v18.s[2] + +StoreOutputPartial1AddModeM2 + tbz x5,#0,ExitKernelM2 + ld1 {v0.s}[0],[x2] + ld1 {v1.s}[0],[x10] + add v16.4s,v16.4s,v0.4s + add v18.4s,v18.4s,v1.4s + st1 {v16.s}[0],[x2] + st1 {v18.s}[0],[x10] + b ExitKernelM2 + +// +// Process 1 row of the matrices. +// + +ProcessNextColumnLoopM1 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + mov x0,x14 // reload matrix A + ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 + mov x3,x15 // reload PackedCountK + ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 + cbz x9,SkipScaleByZeroPointBM1 + ld1 {v30.4s},[x9],#16 // load ZeroPointB0 + ld1 {v31.4s},[x9],#16 // load ZeroPointB1 + mul v16.4s,v30.4s,v8.4s + mul v17.4s,v31.4s,v8.4s + ldr d4,[x0],#8 // load packed A0 + add v16.4s,v2.4s,v16.4s + add v17.4s,v3.4s,v17.4s + b ComputeBlockLoopM1 + +SkipScaleByZeroPointBM1 + ldr d4,[x0],#8 // load packed A0 + add v16.4s,v2.4s,v8.4s + add v17.4s,v3.4s,v8.4s + +ComputeBlockLoopM1 + UdotByElement 16, 0, 4, 0 + UdotByElement 17, 1, 4, 0 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + UdotByElement 16, 0, 4, 1 + UdotByElement 17, 1, 4, 1 + sub x3,x3,#1 + cbz x3,ComputeBlockLoopFinishM1 + ldr d4,[x0],#8 // load packed A0 + ld1 {v0.16b},[x1],#16 // load packed B0 + ld1 {v1.16b},[x1],#16 // load packed B1 + b ComputeBlockLoopM1 + +ComputeBlockLoopFinishM1 + subs x5,x5,#8 // adjust CountN remaining + blo StoreOutputPartialM1 + cbnz x13,SkipAccumulateOutputM1 + ldp q0,q1,[x2] + add v16.4s,v16.4s,v0.4s + add v17.4s,v17.4s,v1.4s + +SkipAccumulateOutputM1 + stp q16,q17,[x2],#32 + cbnz x5,ProcessNextColumnLoopM1 + +ExitKernelM1 + mov x0,#1 // return number of rows handled + EPILOG_RESTORE_REG_PAIR d10,d11,#16 + EPILOG_RESTORE_REG_PAIR d8,d9,#32! + EPILOG_RETURN + +// +// Store the partial 1 to 7 columns either overwriting the output matrix or +// accumulating into the existing contents of the output matrix. +// + +StoreOutputPartialM1 + cbz x13,StoreOutputPartialAddModeM1 + +StoreOutputPartialZeroModeM1 + tbz x5,#2,StoreOutputPartial2ZeroModeM1 + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + +StoreOutputPartial2ZeroModeM1 + tbz x5,#1,StoreOutputPartial1ZeroModeM1 + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + +StoreOutputPartial1ZeroModeM1 + tbz x5,#0,ExitKernelM1 + st1 {v16.s}[0],[x2] + b ExitKernelM1 + +StoreOutputPartialAddModeM1 + tbz x5,#2,StoreOutputPartial2AddModeM1 + ld1 {v0.4s},[x2] + add v16.4s,v16.4s,v0.4s + st1 {v16.4s},[x2],#16 + mov v16.16b,v17.16b // shift remaining elements down + +StoreOutputPartial2AddModeM1 + tbz x5,#1,StoreOutputPartial1AddModeM1 + ld1 {v0.2s},[x2] + add v16.4s,v16.4s,v0.4s + st1 {v16.2s},[x2],#8 + dup v16.4s,v16.s[2] // shift remaining elements down + +StoreOutputPartial1AddModeM1 + tbz x5,#0,ExitKernelM1 + ld1 {v0.s}[0],[x2] + add v16.4s,v16.4s,v0.4s + st1 {v16.s}[0],[x2] + b ExitKernelM1 + + NESTED_END MlasGemmU8X8KernelUdot + + END diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 56c6d2b261..8f8727bc2c 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -245,14 +245,6 @@ void typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE; -typedef -void -(MLASCALL MLAS_GEMM_U8X8_OPERATION)( - const struct MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); - -typedef MLAS_GEMM_U8X8_OPERATION* PMLAS_GEMM_U8X8_OPERATION; - typedef size_t (MLASCALL MLAS_GEMM_U8S8_KERNEL)( @@ -265,7 +257,7 @@ size_t size_t ldc, const int32_t* RowSumVector, const int32_t* ColumnSumVector, - int32_t DepthValue, + const int32_t* ZeroPointB, bool ZeroMode ); @@ -296,7 +288,7 @@ size_t size_t ldc, const int32_t* RowSumVector, const int32_t* ColumnSumVector, - int32_t DepthValue, + const int32_t* ZeroPointB, bool ZeroMode ); @@ -697,26 +689,17 @@ MlasSgemmOperation( ); // -// Quantized integer matrix/matrix multiply operation. +// Quantized integer matrix/matrix dispatch structure. // -struct MLAS_GEMM_U8X8_KERNEL_SSE; -struct MLAS_GEMM_U8S8_KERNEL_AVX2; -struct MLAS_GEMM_U8U8_KERNEL_AVX2; +struct MLAS_GEMM_U8X8_DISPATCH; -template -void -MLASCALL -MlasGemmU8X8Operation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); - -template -void -MLASCALL -MlasGemmU8X8PackedOperation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); +extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchSse; +extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchAvx2; +extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8U8DispatchAvx2; +extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchNeon; +extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchUdot; +extern const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchDefault; // // Quantized depthwise convolution kernels. @@ -767,12 +750,10 @@ struct MLAS_PLATFORM { PMLAS_SGEMM_KERNEL_M1_ROUTINE KernelM1TransposeBRoutine; PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE TransposePackB16x4Routine; PMLAS_GEMM_DOUBLE_KERNEL GemmDoubleKernel; - PMLAS_GEMM_U8X8_OPERATION GemmU8S8Operation; - PMLAS_GEMM_U8X8_OPERATION GemmU8S8PackedOperation; + const MLAS_GEMM_U8X8_DISPATCH* GemmU8S8Dispatch; PMLAS_GEMM_U8S8_KERNEL GemmU8S8Kernel; PMLAS_GEMV_U8S8_KERNEL GemvU8S8Kernel; - PMLAS_GEMM_U8X8_OPERATION GemmU8U8Operation; - PMLAS_GEMM_U8X8_OPERATION GemmU8U8PackedOperation; + const MLAS_GEMM_U8X8_DISPATCH* GemmU8U8Dispatch; PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel; PMLAS_CONV_FLOAT_KERNEL ConvNchwFloatKernel; PMLAS_CONV_FLOAT_KERNEL ConvNchwcFloatKernel; @@ -800,6 +781,10 @@ struct MLAS_PLATFORM { #else static constexpr uint32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; #endif + +#if defined(MLAS_TARGET_ARM64) + const MLAS_GEMM_U8X8_DISPATCH* GemmU8X8Dispatch; +#endif }; extern MLAS_PLATFORM MlasPlatform; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index aeeeb40126..eea76913c1 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -17,6 +17,16 @@ Abstract: #include "mlasi.h" +#if defined(MLAS_TARGET_ARM64) && defined(__linux__) +#include +#include +// N.B. Support building with older versions of asm/hwcap.h that do not define +// this capability bit. +#ifndef HWCAP_ASIMDDP +#define HWCAP_ASIMDDP (1 << 20) +#endif +#endif + // // Stores the platform information. // @@ -120,8 +130,8 @@ Return Value: this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse; this->GemmDoubleKernel = MlasGemmDoubleKernelSse; - this->GemmU8S8Operation = MlasGemmU8X8Operation; - this->GemmU8U8Operation = MlasGemmU8X8Operation; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchSse; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse; this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelSse; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelSse; @@ -206,12 +216,10 @@ Return Value: if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { - this->GemmU8S8Operation = MlasGemmU8X8Operation; - this->GemmU8S8PackedOperation = MlasGemmU8X8PackedOperation; + this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2; - this->GemmU8U8Operation = MlasGemmU8X8Operation; - this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation; + this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2; this->GemmFloatKernel = MlasGemmFloatKernelFma3; @@ -229,7 +237,7 @@ Return Value: this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - + // // Check if the processor supports Hybrid core architecture. // @@ -251,8 +259,7 @@ Return Value: if ((Cpuid7_1[0] & 0x10) != 0) { - this->GemmU8U8Operation = MlasGemmU8X8Operation; - this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation; + this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; } @@ -304,8 +311,7 @@ Return Value: if ((Cpuid7[2] & 0x800) != 0) { - this->GemmU8U8Operation = MlasGemmU8X8Operation; - this->GemmU8U8PackedOperation = MlasGemmU8X8PackedOperation; + this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; } @@ -326,6 +332,24 @@ Return Value: #endif // MLAS_TARGET_AMD64_IX86 +#if defined(MLAS_TARGET_ARM64) + + this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchNeon; + +#if defined(__linux__) + + // + // Check if the processor supports ASIMD dot product instructions. + // + + if ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0) { + this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchUdot; + } + +#endif + +#endif + } size_t diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index f2532986b3..f4a1868db9 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -17,6 +17,68 @@ Abstract: #include "mlasi.h" +// +// Quantized integer matrix/matrix dispatch structure. +// + +typedef +void +(MLAS_GEMM_U8X8_OPERATION)( + const MLAS_GEMM_U8X8_PARAMETERS* Parameters, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); + +typedef +void +(MLAS_GEMM_U8X8_COPY_PACKB_ROUTINE)( + uint8_t* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ); + +struct MLAS_GEMM_U8X8_DISPATCH { + MLAS_GEMM_U8X8_OPERATION* Operation; + MLAS_GEMM_U8X8_OPERATION* PackedOperation; + MLAS_GEMM_U8X8_COPY_PACKB_ROUTINE* CopyPackBRoutine; + size_t PackedK; + size_t PackedStrideK; +}; + +const MLAS_GEMM_U8X8_DISPATCH* +MlasGemmU8X8GetDispatch( + bool BIsSigned + ) +{ + const MLAS_GEMM_U8X8_DISPATCH* GemmU8X8Dispatch; + + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + +#if defined(MLAS_TARGET_AMD64) + if (BIsSigned) { + GemmU8X8Dispatch = MlasPlatform.GemmU8S8Dispatch; + } else { + GemmU8X8Dispatch = MlasPlatform.GemmU8U8Dispatch; + } +#elif defined(MLAS_SSE2_INTRINSICS) + GemmU8X8Dispatch = &MlasGemmU8X8DispatchSse; +#elif defined(MLAS_NEON64_INTRINSICS) + GemmU8X8Dispatch = MlasPlatform.GemmU8X8Dispatch; +#elif defined(MLAS_NEON32_INTRINSICS) && !defined(_MSC_VER) + GemmU8X8Dispatch = &MlasGemmU8X8DispatchNeon; +#else + GemmU8X8Dispatch = &MlasGemmU8X8DispatchDefault; +#endif + + return GemmU8X8Dispatch; +} + // // Define the parameters to execute segments of a QGEMM operation on worker // threads. @@ -25,24 +87,7 @@ Abstract: struct MLAS_GEMM_U8X8_WORK_BLOCK { int32_t ThreadCountM; int32_t ThreadCountN; - size_t RangeStartM; - size_t RangeStartN; - size_t RangeCountM; - size_t RangeCountN; - size_t M; - size_t N; - size_t K; - const uint8_t* A; - size_t lda; - const void* B; - size_t ldb; - int32_t* C; - size_t ldc; - uint8_t offa; - uint8_t offb; - bool BIsPacked; - bool BIsSigned; - const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor; + const MLAS_GEMM_U8X8_PARAMETERS* Parameters; }; // @@ -80,11 +125,116 @@ MlasGemmU8X8ScaleSumBuffer( return MlasGemmU8X8ScaleSumBuffer(SumBuffer, SumBuffer, N, Scale); } +template +MLAS_FORCEINLINE +bool +MlasGemmU8X8TryGemvKernel( + const uint8_t* A, + const uint8_t* B, + size_t ldb, + int32_t* C, + size_t CountK, + size_t CountN, + bool BIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(A); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(C); + MLAS_UNREFERENCED_PARAMETER(CountK); + MLAS_UNREFERENCED_PARAMETER(CountN); + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + + return false; +} + +template +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ); + +template +MLAS_FORCEINLINE +void +MlasGemmU8X8FixupZeroPointB( + const uint8_t* PackedZeroPointB, + int32_t* ZeroPointBBuffer, + size_t N, + bool BIsSigned + ) +{ + int32_t ZeroPointB; + + for (size_t n = 0; n < N; n++) { + + ZeroPointB = typename KernelType::OffsetBType(PackedZeroPointB[n]); + ZeroPointB = MlasGemmU8X8FixupZeroPointB(ZeroPointB, BIsSigned); + + ZeroPointBBuffer[n] = -ZeroPointB; + } + + // + // Fill the misaligned slots of the zero point buffer with zeroes to guard + // against tools that check for uninitialized data usage. + // + + size_t AlignedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); + + for (size_t n = N; n < AlignedN; n++) { + ZeroPointBBuffer[n] = 0; + } +} + +template +void +MlasGemmU8X8CopyPackA( + typename KernelType::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ); + +template +void +MlasGemmU8X8CopyPackB( + typename KernelType::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ); + +template +size_t +MlasGemmU8X8Kernel( + const typename KernelType::PackedAType* A, + const typename KernelType::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ); + template void -MLASCALL MlasGemmU8X8Operation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock + const MLAS_GEMM_U8X8_PARAMETERS* Parameters, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN ) /*++ @@ -95,7 +245,15 @@ Routine Description: Arguments: - WorkBlock - Supplies the structure containing the GEMM parameters. + Parameters - Supplies the structure containing the GEMM parameters. + + RangeStartM - Supplies the starting row index to output. + + RangeCountM - Supplies the number of rows to output. + + RangeStartN - Supplies the starting column index to output. + + RangeCountN - Supplies the number of columns to output. Return Value: @@ -110,38 +268,42 @@ Return Value: MLAS_DECLSPEC_ALIGN(int32_t RowSumBuffer[Strides.M], 64); MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[Strides.N], 64); + MLAS_DECLSPEC_ALIGN(int32_t ZeroPointBBuffer[Strides.N], 64); - const size_t M = WorkBlock->RangeCountM; - const size_t N = WorkBlock->RangeCountN; - const size_t K = WorkBlock->K; + const size_t K = Parameters->K; - const size_t lda = WorkBlock->lda; - const size_t ldb = WorkBlock->ldb; - const size_t ldc = WorkBlock->ldc; + const size_t lda = Parameters->lda; + const size_t ldb = Parameters->ldb; + const size_t ldc = Parameters->ldc; - const uint8_t* A = WorkBlock->A + WorkBlock->RangeStartM * lda; - const uint8_t* B = (const uint8_t*)WorkBlock->B + WorkBlock->RangeStartN; - int32_t* C = WorkBlock->C + WorkBlock->RangeStartM * ldc + WorkBlock->RangeStartN; + const uint8_t* A = Parameters->A + RangeStartM * lda; + const uint8_t* B = (const uint8_t*)Parameters->B + RangeStartN; + int32_t* C = Parameters->C + RangeStartM * ldc + RangeStartN; + const uint8_t* PackedZeroPointB = Parameters->PerColumnZeroPoints ? + Parameters->ZeroPointB + RangeStartN : nullptr; - int32_t offa = WorkBlock->offa; - int32_t offb = typename KernelType::OffsetBType(WorkBlock->offb); + int32_t ZeroPointA = Parameters->ZeroPointA; + int32_t ZeroPointB = typename KernelType::OffsetBType(*Parameters->ZeroPointB); // // Try to use a GEMV kernel if supported by this kernel type. // - if ((M == 1) && (offa == 0) && (offb == 0) && WorkBlock->OutputProcessor == nullptr) { - if (KernelType::TryGemvKernel(A, B, ldb, C, K, N, WorkBlock->BIsSigned)) { + if ((RangeCountM == 1) && + (ZeroPointA == 0) && (PackedZeroPointB == nullptr) && (ZeroPointB == 0) && + (Parameters->OutputProcessor == nullptr)) { + if (MlasGemmU8X8TryGemvKernel(A, B, ldb, C, K, RangeCountN, Parameters->BIsSigned)) { return; } } // - // Fixup the sign bit of the zero point offset of matrix B if the data is - // the opposite format of the kernel implementation. + // Fixup the sign bit of the per-matrix zero point offset of matrix B if the + // data is the opposite format of the kernel implementation. This value is + // ignored if per-column zero point offsets are used instead. // - offb = KernelType::FixupZeroPointB(offb, WorkBlock->BIsSigned); + ZeroPointB = MlasGemmU8X8FixupZeroPointB(ZeroPointB, Parameters->BIsSigned); // // Step through each slice of matrix B along the K dimension. @@ -153,48 +315,91 @@ Return Value: CountK = std::min(K - k, Strides.K); + const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; + // // Step through each slice of matrix B along the N dimension. // size_t CountN; - for (size_t n = 0; n < N; n += CountN) { + for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(N - n, Strides.N); + CountN = std::min(RangeCountN - n, Strides.N); + + // + // Fixup the sign bit of the per-column zero point offsets of matrix B + // if the data is the opposite format of the kernel implementation. + // + + if (PackedZeroPointB != nullptr) { + MlasGemmU8X8FixupZeroPointB( + PackedZeroPointB + n, + ZeroPointBBuffer, + CountN, + Parameters->BIsSigned); + } // // Copy a panel of matrix B to a local packed buffer. // - KernelType::CopyPackB(PanelB, B + n, ldb, CountN, CountK, - ColumnSumBuffer, WorkBlock->BIsSigned); + MlasGemmU8X8CopyPackB( + PanelB, + B + n, + ldb, + CountN, + CountK, + ColumnSumBuffer, + Parameters->BIsSigned); - MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, CountN, -offa); + MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, CountN, -ZeroPointA); // // Step through each slice of matrix A along the M dimension. // - const int32_t DepthValue = int32_t(CountK) * offa * offb; - const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / - KernelType::PackedK; - int32_t* c = C + n; size_t CountM; - for (size_t m = 0; m < M; m += CountM) { + for (size_t m = 0; m < RangeCountM; m += CountM) { - CountM = std::min(M - m, Strides.M); + CountM = std::min(RangeCountM - m, Strides.M); // // Copy a panel of matrix A to a local packed buffer. // - KernelType::CopyPackA(PanelA, A + m * lda, lda, CountM, CountK, + MlasGemmU8X8CopyPackA( + PanelA, + A + m * lda, + lda, + CountM, + CountK, RowSumBuffer); - MlasGemmU8X8ScaleSumBuffer(RowSumBuffer, CountM, -offb); + // + // Apply the global depth value constant without the ZeroPointB scaling from: + // + // (A[i] - ZeroPointA) * (B[i] - ZeroPointB) + // ==> + // A[i] * B[i] - A[i] * ZeroPointB - B[i] * ZeroPointA + ZeroPointA * ZeroPointB + // + // The ZeroPointB term is factored out and either applied below for per-matrix + // quantization or inside the kernel for per-column quantization. + // + + for (size_t mm = 0; mm < CountM; mm++) { + RowSumBuffer[mm] -= int32_t(CountK) * ZeroPointA; + } + + // + // Scale the row sums by the per-matrix zero point offset of matrix B. + // + + if (PackedZeroPointB == nullptr) { + MlasGemmU8X8ScaleSumBuffer(RowSumBuffer, CountM, -ZeroPointB); + } // // Step through the rows of the local packed buffer. @@ -209,19 +414,27 @@ Return Value: while (RowsRemaining > 0) { - size_t RowsHandled; + size_t RowsHandled = MlasGemmU8X8Kernel( + pa, + PanelB, + c, + PackedCountK, + RowsRemaining, + CountN, + ldc, + RowSums, + ColumnSumBuffer, + (PackedZeroPointB != nullptr) ? ZeroPointBBuffer : nullptr, + ZeroMode); - RowsHandled = KernelType::GemmKernel(pa, PanelB, c, PackedCountK, - RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer, - DepthValue, ZeroMode); - - if (PostProcess && WorkBlock->OutputProcessor != nullptr) { - WorkBlock->OutputProcessor->Process(WorkBlock->C, - WorkBlock->RangeStartM + m + CountM - RowsRemaining, - WorkBlock->RangeStartN + n, - RowsHandled, - CountN, - WorkBlock->ldc); + if (PostProcess && Parameters->OutputProcessor != nullptr) { + Parameters->OutputProcessor->Process( + Parameters->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, + RowsHandled, + CountN, + Parameters->ldc); } c += ldc * RowsHandled; @@ -239,9 +452,12 @@ Return Value: template void -MLASCALL MlasGemmU8X8PackedOperation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock + const MLAS_GEMM_U8X8_PARAMETERS* Parameters, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN ) /*++ @@ -252,7 +468,15 @@ Routine Description: Arguments: - WorkBlock - Supplies the structure containing the GEMM parameters. + Parameters - Supplies the structure containing the GEMM parameters. + + RangeStartM - Supplies the starting row index to output. + + RangeCountM - Supplies the number of rows to output. + + RangeStartN - Supplies the starting column index to output. + + RangeCountN - Supplies the number of columns to output. Return Value: @@ -266,37 +490,39 @@ Return Value: MLAS_DECLSPEC_ALIGN(int32_t RowSumBuffer[Strides.M], 64); MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[Strides.N], 64); + MLAS_DECLSPEC_ALIGN(int32_t ZeroPointBBuffer[Strides.N], 64); - const size_t M = WorkBlock->RangeCountM; - const size_t N = WorkBlock->RangeCountN; - const size_t K = WorkBlock->K; + const size_t K = Parameters->K; - const size_t lda = WorkBlock->lda; - const size_t ldc = WorkBlock->ldc; + const size_t lda = Parameters->lda; + const size_t ldc = Parameters->ldc; - const uint8_t* A = WorkBlock->A + WorkBlock->RangeStartM * lda; - const uint8_t* PackedB = (const uint8_t*)WorkBlock->B; - int32_t* C = WorkBlock->C + WorkBlock->RangeStartM * ldc + WorkBlock->RangeStartN; + const uint8_t* A = Parameters->A + RangeStartM * lda; + const uint8_t* PackedB = (const uint8_t*)Parameters->B; + int32_t* C = Parameters->C + RangeStartM * ldc + RangeStartN; + const uint8_t* PackedZeroPointB = Parameters->PerColumnZeroPoints ? + Parameters->ZeroPointB + RangeStartN : nullptr; - int32_t offa = WorkBlock->offa; - int32_t offb = typename KernelType::OffsetBType(WorkBlock->offb); + int32_t ZeroPointA = Parameters->ZeroPointA; + int32_t ZeroPointB = typename KernelType::OffsetBType(*Parameters->ZeroPointB); // - // Flip the sign bit of the zero point offset of matrix B if the data is - // the opposite format of the kernel implementation. + // Fixup the sign bit of the per-matrix zero point offset of matrix B if the + // data is the opposite format of the kernel implementation. This value is + // ignored if per-column zero point offsets are used instead. // - offb = KernelType::FixupZeroPointB(offb, WorkBlock->BIsSigned); + ZeroPointB = MlasGemmU8X8FixupZeroPointB(ZeroPointB, Parameters->BIsSigned); // // Extract the pointer to the column sum buffer from the packed matrix. // const size_t AlignedN = - (WorkBlock->N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); + (Parameters->N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); const int32_t* PackedColumnSumBuffer = (const int32_t*)PackedB; PackedB = (const uint8_t*)(PackedColumnSumBuffer + AlignedN); - PackedColumnSumBuffer += WorkBlock->RangeStartN; + PackedColumnSumBuffer += RangeStartN; // // Step through each slice of matrix B along the K dimension. @@ -308,8 +534,7 @@ Return Value: CountK = std::min(K - k, Strides.K); - const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / - KernelType::PackedK; + const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; if (k > 0) { std::fill_n(ColumnSumBuffer, Strides.N, 0); @@ -321,37 +546,75 @@ Return Value: size_t CountN; - for (size_t n = 0; n < N; n += CountN) { + for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(N - n, Strides.N); + CountN = std::min(RangeCountN - n, Strides.N); if (k == 0) { MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, PackedColumnSumBuffer + n, - CountN, -offa); + CountN, -ZeroPointA); + } + + // + // Fixup the sign bit of the per-column zero point offsets of matrix B + // if the data is the opposite format of the kernel implementation. + // + + if (PackedZeroPointB != nullptr) { + MlasGemmU8X8FixupZeroPointB( + PackedZeroPointB + n, + ZeroPointBBuffer, + CountN, + Parameters->BIsSigned); } // // Step through each slice of matrix A along the M dimension. // - const int32_t DepthValue = int32_t(CountK) * offa * offb; - const uint8_t* b = PackedB + (WorkBlock->RangeStartN + n) * + const uint8_t* b = PackedB + (RangeStartN + n) * KernelType::PackedK * PackedCountK; int32_t* c = C + n; size_t CountM; - for (size_t m = 0; m < M; m += CountM) { + for (size_t m = 0; m < RangeCountM; m += CountM) { - CountM = std::min(M - m, Strides.M); + CountM = std::min(RangeCountM - m, Strides.M); // // Copy a panel of matrix A to a local packed buffer. // - KernelType::CopyPackA(PanelA, A + m * lda, lda, CountM, CountK, + MlasGemmU8X8CopyPackA( + PanelA, + A + m * lda, + lda, + CountM, + CountK, RowSumBuffer); - MlasGemmU8X8ScaleSumBuffer(RowSumBuffer, CountM, -offb); + // + // Apply the global depth value constant without the ZeroPointB scaling from: + // + // (A[i] - ZeroPointA) * (B[i] - ZeroPointB) + // ==> + // A[i] * B[i] - A[i] * ZeroPointB - B[i] * ZeroPointA + ZeroPointA * ZeroPointB + // + // The ZeroPointB term is factored out and either applied below for per-matrix + // quantization or inside the kernel for per-column quantization. + // + + for (size_t mm = 0; mm < CountM; mm++) { + RowSumBuffer[mm] -= int32_t(CountK) * ZeroPointA; + } + + // + // Scale the row sums by the per-matrix zero point offset of matrix B. + // + + if (PackedZeroPointB == nullptr) { + MlasGemmU8X8ScaleSumBuffer(RowSumBuffer, CountM, -ZeroPointB); + } // // Step through the rows of the local packed buffer. @@ -366,20 +629,27 @@ Return Value: while (RowsRemaining > 0) { - size_t RowsHandled; + size_t RowsHandled = MlasGemmU8X8Kernel( + pa, + b, + c, + PackedCountK, + RowsRemaining, + CountN, + ldc, + RowSums, + ColumnSumBuffer, + (PackedZeroPointB != nullptr) ? ZeroPointBBuffer : nullptr, + ZeroMode); - RowsHandled = KernelType::GemmKernel(pa, b, c, PackedCountK, - RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer, - DepthValue, ZeroMode); - - if (PostProcess && WorkBlock->OutputProcessor != nullptr) { - WorkBlock->OutputProcessor->Process( - WorkBlock->C, - WorkBlock->RangeStartM + m + CountM - RowsRemaining, - WorkBlock->RangeStartN + n, + if (PostProcess && Parameters->OutputProcessor != nullptr) { + Parameters->OutputProcessor->Process( + Parameters->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, RowsHandled, CountN, - WorkBlock->ldc); + Parameters->ldc); } c += ldc * RowsHandled; @@ -397,42 +667,44 @@ Return Value: #if defined(MLAS_SSE2_INTRINSICS) +struct MLAS_GEMM_U8X8_KERNEL_SSE +{ + typedef int16_t PackedAType; + typedef int16_t PackedBType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 2; + static constexpr MLAS_GEMM_U8X8_STRIDES Strides{12, 128, 128}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_SSE::PackedK; +constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_SSE::Strides; + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_SSE::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> void -MlasGemmU8X8CopyPackASse( - int16_t* D, +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* D, const uint8_t* A, size_t lda, size_t CountM, size_t CountK, int32_t* RowSumBuffer ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - A - Supplies the address of the source matrix. - - lda - Supplies the number of elements per row of the source matrix. - - CountM - Supplies the number of rows of the source matrix to copy. - - CountK - Supplies the number of columns of the source matrix to copy. - - RowSumBuffer - Supplies the address of the buffer to receive the sums of - the elements along each of the rows. - -Return Value: - - None. - ---*/ { const __m128i ZeroVector = _mm_setzero_si128(); const __m128i OnesWordBroadcast = _mm_set1_epi16(1); @@ -524,9 +796,10 @@ Return Value: } } +MLAS_FORCEINLINE void MlasGemmU8X8CopyPackBProcessSse( - int16_t* D, + MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, __m128i BytesRow0, __m128i BytesRow1, __m128i BitFlipVector, @@ -547,9 +820,10 @@ MlasGemmU8X8CopyPackBProcessSse( _mm_storeu_si128((__m128i*)&D[8], WordsInterleaved1); } +template<> void -MlasGemmU8X8CopyPackBSse( - int16_t* D, +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, const uint8_t* B, size_t ldb, size_t CountN, @@ -557,36 +831,6 @@ MlasGemmU8X8CopyPackBSse( int32_t* ColumnSumBuffer, bool BIsSigned ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountN - Supplies the number of columns of the source matrix to copy. - - CountK - Supplies the number of rows of the source matrix to copy. - - ColumnSumBuffer - Supplies the address of the buffer to receive the sums of - the elements along each of the columns. - - BIsSigned - Supplies true if the source matrix is signed data, else false - if the source matrix is unsigned data. - -Return Value: - - None. - ---*/ { const __m128i OnesWordBroadcast = _mm_set1_epi16(1); const __m128i BitFlipVector = _mm_set1_epi32(BIsSigned ? 0 : 0x80808080); @@ -612,7 +856,7 @@ Return Value: // 128 to avoid overflowing these signed 16-bit accumulators. // - while (k >= 2) { + while (k >= MLAS_GEMM_U8X8_KERNEL_SSE::PackedK) { __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&b[ldb]); @@ -633,10 +877,6 @@ Return Value: D += 16; } - // - // Reduce the partial accumulators. - // - ColumnSums[0] = _mm_madd_epi16(ColumnSums[0], OnesWordBroadcast); ColumnSums[1] = _mm_madd_epi16(ColumnSums[1], OnesWordBroadcast); @@ -669,7 +909,7 @@ Return Value: // buffer and write to the packed buffer. // - while (k >= 2) { + while (k >= MLAS_GEMM_U8X8_KERNEL_SSE::PackedK) { const uint8_t* bcopy = b; uint8_t* padded = PaddedMatrixBData; @@ -709,11 +949,6 @@ Return Value: MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); } - // - // Reduce the sum for the packed columns and multiply by the zero point - // offset of the other source matrix. - // - ColumnSums[0] = _mm_madd_epi16(ColumnSums[0], OnesWordBroadcast); ColumnSums[1] = _mm_madd_epi16(ColumnSums[1], OnesWordBroadcast); @@ -737,74 +972,54 @@ MlasGemmU8X8MultiplyAccumulateRowSse( Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_madd_epi16(BElements1, ABroadcast)); } -void -MlasGemmU8X8KernelSse( - const int16_t* A, - const int16_t* B, +template<> +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* B, int32_t* C, size_t PackedCountK, + size_t CountM, size_t CountN, + size_t ldc, const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, - int32_t DepthValue, + const int32_t* ZeroPointB, bool ZeroMode ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - single row. - -Arguments: - - A - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackASse. - - B - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBSse. - - C - Supplies the address of matrix C. - - PackedCountK - Supplies the number of packed columns from matrix A and the - number of packed rows from matrix B to iterate over. - - CountN - Supplies the number of columns from matrix B and matrix C to iterate - over. - - RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the - zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrixA multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - None. - ---*/ { + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + while (CountN > 0) { __m128i Accumulators[2]; // - // Initialize the accumulators with the sum of the global depth value - // constant, the column sums, and the row sums. + // Initialize the accumulators with the row and column sums. // - Accumulators[0] = _mm_set1_epi32(DepthValue); - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_set1_epi32(RowSumBuffer[0])); - Accumulators[1] = Accumulators[0]; + int32_t RowSumValue = RowSumBuffer[0]; + + if (ZeroPointB != nullptr) { + + int32_t ScaledRowSumBuffer[8]; + + for (size_t i = 0; i < 8; i++) { + ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; + } + + ZeroPointB += 8; + + Accumulators[0] = _mm_loadu_si128((__m128i*)&ScaledRowSumBuffer[0]); + Accumulators[1] = _mm_loadu_si128((__m128i*)&ScaledRowSumBuffer[4]); + + } else { + + Accumulators[0] = _mm_set1_epi32(RowSumValue); + Accumulators[1] = Accumulators[0]; + } + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&ColumnSumBuffer[0])); Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*)&ColumnSumBuffer[4])); ColumnSumBuffer += 8; @@ -912,125 +1127,18 @@ Return Value: CountN = 0; } } + + return 1; } -struct MLAS_GEMM_U8X8_KERNEL_SSE -{ - typedef int16_t PackedAType; - typedef int16_t PackedBType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 2; - static constexpr MLAS_GEMM_U8X8_STRIDES Strides{12, 128, 128}; - - MLAS_FORCEINLINE - static - bool - TryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned - ) - { - MLAS_UNREFERENCED_PARAMETER(A); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(CountK); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - return false; - } - - MLAS_FORCEINLINE - static - int32_t - FixupZeroPointB( - int32_t offb, - bool BIsSigned - ) - { - if (!BIsSigned) { - offb = OffsetBType(offb ^ 0x80); - } - - return offb; - } - - MLAS_FORCEINLINE - static - void - CopyPackA( - PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) - { - MlasGemmU8X8CopyPackASse(D, A, lda, CountM, CountK, RowSumBuffer); - } - - MLAS_FORCEINLINE - static - void - CopyPackB( - PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) - { - MlasGemmU8X8CopyPackBSse(D, B, ldb, CountN, CountK, ColumnSumBuffer, - BIsSigned); - } - - MLAS_FORCEINLINE - static - size_t - GemmKernel( - const PackedAType* A, - const PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - int32_t DepthValue, - bool ZeroMode - ) - { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - MlasGemmU8X8KernelSse(A, B, C, PackedCountK, CountN, RowSumBuffer, - ColumnSumBuffer, DepthValue, ZeroMode); - - return 1; - } +const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchSse = { + MlasGemmU8X8Operation, + nullptr, + nullptr, + MLAS_GEMM_U8X8_KERNEL_SSE::PackedK, + 0, }; -constexpr size_t MLAS_GEMM_U8X8_KERNEL_SSE::PackedK; -constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_SSE::Strides; - -template -void -MLASCALL -MlasGemmU8X8Operation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); - #endif #if defined(MLAS_TARGET_AMD64) @@ -1103,112 +1211,107 @@ struct MLAS_GEMM_U8S8_KERNEL_AVX2 static constexpr size_t PackedK = 4; static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 256, 128}; static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{48, 256, 384}; - - MLAS_FORCEINLINE - static - bool - TryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned - ) - { - if (BIsSigned) { - MlasPlatform.GemvU8S8Kernel(A, B, C, CountK, CountN, ldb); - return true; - } - - return false; - } - - MLAS_FORCEINLINE - static - int32_t - FixupZeroPointB( - int32_t offb, - bool BIsSigned - ) - { - if (!BIsSigned) { - offb = OffsetBType(offb ^ 0x80); - } - - return offb; - } - - MLAS_FORCEINLINE - static - void - CopyPackA( - PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) - { - MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); - } - - MLAS_FORCEINLINE - static - void - CopyPackB( - PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) - { - MlasGemmU8S8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer, - BIsSigned); - } - - MLAS_FORCEINLINE - static - size_t - GemmKernel( - const PackedAType* A, - const PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - int32_t DepthValue, - bool ZeroMode - ) - { - return MlasPlatform.GemmU8S8Kernel(A, B, C, PackedCountK, CountM, CountN, - ldc, RowSumBuffer, ColumnSumBuffer, DepthValue, ZeroMode); - } }; constexpr size_t MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides; -template -void -MlasGemmU8X8Operation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); +template<> +MLAS_FORCEINLINE +bool +MlasGemmU8X8TryGemvKernel( + const uint8_t* A, + const uint8_t* B, + size_t ldb, + int32_t* C, + size_t CountK, + size_t CountN, + bool BIsSigned + ) +{ + if (BIsSigned) { + MlasPlatform.GemvU8S8Kernel(A, B, C, CountK, CountN, ldb); + return true; + } -template + return false; +} + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8S8_KERNEL_AVX2::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +MLAS_FORCEINLINE void -MlasGemmU8X8PackedOperation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) +{ + MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); +} + +template<> +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + MlasGemmU8S8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* A, + const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + return MlasPlatform.GemmU8S8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchAvx2 = { + MlasGemmU8X8Operation, + MlasGemmU8X8PackedOperation, + MlasGemmU8X8CopyPackB, + MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK, + MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K, +}; struct MLAS_GEMM_U8U8_KERNEL_AVX2 { @@ -1219,115 +1322,86 @@ struct MLAS_GEMM_U8U8_KERNEL_AVX2 static constexpr size_t PackedK = 2; static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 256, 128}; static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{48, 256, 384}; - - MLAS_FORCEINLINE - static - bool - TryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned - ) - { - MLAS_UNREFERENCED_PARAMETER(A); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(CountK); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - return false; - } - - MLAS_FORCEINLINE - static - int32_t - FixupZeroPointB( - int32_t offb, - bool BIsSigned - ) - { - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - return offb; - } - - MLAS_FORCEINLINE - static - void - CopyPackA( - PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) - { - MlasGemmU8U8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); - } - - MLAS_FORCEINLINE - static - void - CopyPackB( - PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) - { - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - MlasGemmU8U8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer); - } - - MLAS_FORCEINLINE - static - size_t - GemmKernel( - const PackedAType* A, - const PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - int32_t DepthValue, - bool ZeroMode - ) - { - return MlasPlatform.GemmU8U8Kernel(A, B, C, PackedCountK, CountM, CountN, - ldc, RowSumBuffer, ColumnSumBuffer, DepthValue, ZeroMode); - } }; constexpr size_t MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides; -template -void -MLASCALL -MlasGemmU8X8Operation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); -template + return ZeroPointB; +} + +template<> +MLAS_FORCEINLINE void -MlasGemmU8X8PackedOperation( - const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock - ); +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) +{ + MlasGemmU8U8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); +} + +template<> +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + + MlasGemmU8U8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* A, + const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + return MlasPlatform.GemmU8U8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8U8DispatchAvx2 = { + MlasGemmU8X8Operation, + MlasGemmU8X8PackedOperation, + MlasGemmU8X8CopyPackB, + MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK, + MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K, +}; #endif @@ -1339,60 +1413,65 @@ MlasGemmU8X8PackedOperation( // N.B. The kernel has not been ported to build with the Windows ARM32 toolset. // -extern "C" -size_t -MLASCALL -MlasGemmU8X8KernelNeon( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - int32_t DepthValue, - bool ZeroMode - ); +extern "C" { + size_t + MLASCALL + MlasGemmU8X8KernelNeon( + const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB, + bool ZeroMode + ); +} + +struct MLAS_GEMM_U8X8_KERNEL_NEON +{ + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 4; + static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{24, 128, 256}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_NEON::PackedK; +constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::Strides; +constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides; + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_NEON::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> void -MLASCALL -MlasGemmU8X8CopyPackANeon( - uint8_t* D, +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, const uint8_t* A, size_t lda, size_t CountM, size_t CountK, int32_t* RowSumBuffer ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - A - Supplies the address of the source matrix. - - lda - Supplies the number of elements per row of the source matrix. - - CountM - Supplies the number of rows of the source matrix to copy. - - CountK - Supplies the number of columns of the source matrix to copy. - - RowSumBuffer - Supplies the address of the buffer to receive the sums of - the elements along each of the rows. - -Return Value: - - None. - ---*/ { uint8_t PaddedMatrixAData[16]; @@ -1696,10 +1775,10 @@ MlasGemmU8X8CopyPackBProcessNeon( #endif } +template<> void -MLASCALL -MlasGemmU8X8CopyPackBNeon( - uint8_t* D, +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* D, const uint8_t* B, size_t ldb, size_t CountN, @@ -1707,40 +1786,11 @@ MlasGemmU8X8CopyPackBNeon( int32_t* ColumnSumBuffer, bool BIsSigned ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountN - Supplies the number of columns of the source matrix to copy. - - CountK - Supplies the number of rows of the source matrix to copy. - - ColumnSumBuffer - Supplies the address of the buffer to receive the sums of - the elements along each of the columns. - - BIsSigned - Supplies true if the source matrix is signed data, else false - if the source matrix is unsigned data. - -Return Value: - - None. - ---*/ { const uint8x8_t BitFlipVector = vdup_n_u8(BIsSigned ? 0x80 : 0); const uint8x8_t ZeroVector = vmov_n_u8(0); - const size_t AlignedCountK = (CountK + 3) & ~3; + const size_t AlignedCountK = + (CountK + MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1); // // Process 8 columns of matrix B in a loop. @@ -1821,112 +1871,607 @@ Return Value: } } -struct MLAS_GEMM_U8X8_KERNEL_NEON +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetBType; + return MlasGemmU8X8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 128, 256}; - static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{24, 256, 128}; +const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchNeon = { + MlasGemmU8X8Operation, + MlasGemmU8X8PackedOperation, + MlasGemmU8X8CopyPackB, + MLAS_GEMM_U8X8_KERNEL_NEON::PackedK, + MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K, +}; - MLAS_FORCEINLINE - static - bool - TryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned - ) - { - MLAS_UNREFERENCED_PARAMETER(A); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(CountK); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(BIsSigned); +#endif - return false; - } +#if defined(MLAS_NEON64_INTRINSICS) - MLAS_FORCEINLINE - static - int32_t - FixupZeroPointB( - int32_t offb, - bool BIsSigned - ) - { - if (BIsSigned) { - offb = OffsetBType(offb ^ 0x80); - } +// +// Define the prototypes of the NEON UDOT routines written in assembly. +// - return offb; - } +extern "C" { - MLAS_FORCEINLINE - static - void - CopyPackA( - PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) - { - MlasGemmU8X8CopyPackANeon(D, A, lda, CountM, CountK, RowSumBuffer); - } - - MLAS_FORCEINLINE - static - void - CopyPackB( - PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) - { - MlasGemmU8X8CopyPackBNeon(D, B, ldb, CountN, CountK, ColumnSumBuffer, - BIsSigned); - } - - MLAS_FORCEINLINE - static size_t - GemmKernel( - const PackedAType* A, - const PackedBType* B, + MLASCALL + MlasGemmU8X8KernelUdot( + const uint8_t* A, + const uint8_t* B, int32_t* C, size_t PackedCountK, size_t CountM, size_t CountN, size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - int32_t DepthValue, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB, bool ZeroMode - ) - { - return MlasGemmU8X8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, DepthValue, ZeroMode); - } + ); +} + +struct MLAS_GEMM_U8X8_KERNEL_UDOT +{ + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{24, 128, 384}; }; -constexpr size_t MLAS_GEMM_U8X8_KERNEL_NEON::PackedK; -constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::Strides; -constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides; +constexpr size_t MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK; +constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_UDOT::Strides; +constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides; + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UDOT::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +void +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) +{ + uint8_t PaddedMatrixAData[16]; + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] + // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 4) { + + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a1 + lda; + const uint8_t* a3 = a2 + lda; + + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + + uint32x4_t v0 = vld1q_u32(reinterpret_cast(a0)); + a0 += 16; + uint32x4_t v1 = vld1q_u32(reinterpret_cast(a1)); + a1 += 16; + uint32x4_t v2 = vld1q_u32(reinterpret_cast(a2)); + a2 += 16; + uint32x4_t v3 = vld1q_u32(reinterpret_cast(a3)); + a3 += 16; + + uint32x4_t z0 = vzip1q_u32(v0, v2); + uint32x4_t z1 = vzip2q_u32(v0, v2); + uint32x4_t z2 = vzip1q_u32(v1, v3); + uint32x4_t z3 = vzip2q_u32(v1, v3); + + v0 = vzip1q_u32(z0, z2); + v1 = vzip2q_u32(z0, z2); + v2 = vzip1q_u32(z1, z3); + v3 = vzip2q_u32(z1, z3); + + vst1q_u8(&D[0], vreinterpretq_u8_u32(v0)); + vst1q_u8(&D[16], vreinterpretq_u8_u32(v1)); + vst1q_u8(&D[32], vreinterpretq_u8_u32(v2)); + vst1q_u8(&D[48], vreinterpretq_u8_u32(v3)); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v0))); + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v1))); + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v2))); + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v3))); + + D += 64; + k -= 16; + } + + uint32x4_t GatherVector = vmovq_n_u32(0); + + while (k >= 4) { + + GatherVector = vld1q_lane_u32(reinterpret_cast(a0), GatherVector, 0); + a0 += 4; + GatherVector = vld1q_lane_u32(reinterpret_cast(a1), GatherVector, 1); + a1 += 4; + GatherVector = vld1q_lane_u32(reinterpret_cast(a2), GatherVector, 2); + a2 += 4; + GatherVector = vld1q_lane_u32(reinterpret_cast(a3), GatherVector, 3); + a3 += 4; + + uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector); + vst1q_u8(D, PackedVector); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector)); + + D += 16; + k -= 4; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* d = PaddedMatrixAData; + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + while (k > 0) { + + d[0] = *a0++; + d[4] = *a1++; + d[8] = *a2++; + d[12] = *a3++; + + d += 1; + k -= 1; + } + + uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, PackedVector); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector)); + + D += 16; + } + + if (((CountK - 1) & 7) < 4) { + + vst1q_u8(D, vmovq_n_u8(0)); + + D += 16; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 8 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 B0 B1 B2 B3 ] + // [ A4 A5 A6 A7 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of four, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + + size_t k = CountK; + uint32x2_t RowSums = vmov_n_u32(0); + uint32x2_t GatherVector = vmov_n_u32(0); + + while (k >= 4) { + + GatherVector = vld1_lane_u32(reinterpret_cast(a0), GatherVector, 0); + a0 += 4; + GatherVector = vld1_lane_u32(reinterpret_cast(a1), GatherVector, 1); + a1 += 4; + + uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector); + vst1_u8(D, PackedVector); + + RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector)); + + D += 8; + k -= 4; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* d = PaddedMatrixAData; + + vst1_u8(PaddedMatrixAData, vmov_n_u8(0)); + + while (k > 0) { + + d[0] = *a0++; + d[4] = *a1++; + + d += 1; + k -= 1; + } + + uint8x8_t PackedVector = vld1_u8(PaddedMatrixAData); + vst1_u8(D, PackedVector); + + RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector)); + + D += 8; + } + + if (((CountK - 1) & 7) < 4) { + + vst1_u8(D, vmov_n_u8(0)); + + D += 8; + } + + vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 4 byte with the following pattern: + // + // [ A0 A1 A2 A3 ] + // [ A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of four, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + + uint8x16_t v = vld1q_u8(a); + a += 16; + + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + uint8x16_t v = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + } + +#if defined(_M_ARM64) + // N.B. The workaround of defining a local vaddvq_u32 doesn't work here + // as VS2019 added new intrinsics to make the operation work. Also, not + // all build environments using VS2019 have the up-to-date arm64_neon.h, + // so fallback to pairwise addition. + RowSums = vpaddq_u32(RowSums, RowSums); + RowSums = vpaddq_u32(RowSums, RowSums); + vst1q_lane_u32(reinterpret_cast(RowSumBuffer), RowSums, 0); +#else + *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); +#endif + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessUdot( + MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, + uint8x8_t BytesRow[4], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2] + ) +{ + uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); + uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); + + uint8x16x2_t zw = vzipq_u8(v02, v13); + uint16x8x2_t zd = vzipq_u16(vreinterpretq_u16_u8(zw.val[0]), vreinterpretq_u16_u8(zw.val[1])); + + vst1q_u8(&D[0], vreinterpretq_u8_u16(zd.val[0])); + vst1q_u8(&D[16], vreinterpretq_u8_u16(zd.val[1])); + + ColumnSums[0] = vpadalq_u16(ColumnSums[0], vpaddlq_u8(vreinterpretq_u8_u16(zd.val[0]))); + ColumnSums[1] = vpadalq_u16(ColumnSums[1], vpaddlq_u8(vreinterpretq_u8_u16(zd.val[1]))); +} + +template<> +void +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + const uint8x16_t ZeroVector = vmovq_n_u8(0); + const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 0x80 : 0); + uint8x8_t BytesRow[4]; + + // + // Process 8 columns of matrix B in a loop. + // + // The buffer is packed as a series of 16 byte vectors where eight rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] + // [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ] + // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] + // [ E4 E5 E6 E7 F4 F5 F6 F7 G4 G5 G6 G7 H4 H5 H6 H7 ] + // + // Copy columns from matrix B to the packed buffer. Signed buffers are + // converted to unsigned buffers in order to share a common kernel. + // + // If CountK is not aligned to a multiple of eight, then the packed buffer + // is padded with zero vectors. + // + // If CountN is not aligned to a multiple of four, then the extra columns + // are padded with zeroes. + // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + uint32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B and write to the packed buffer. + // + + while (k >= 4) { + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = vld1_u8(&b[ldb * 1]); + BytesRow[2] = vld1_u8(&b[ldb * 2]); + BytesRow[3] = vld1_u8(&b[ldb * 3]); + + MlasGemmU8X8CopyPackBProcessUdot(D, BytesRow, BitFlipVector, ColumnSums); + + b += ldb * 4; + D += 32; + k -= 4; + } + + if (k > 0) { + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); + BytesRow[2] = (k > 2) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); + BytesRow[3] = vget_low_u8(BitFlipVector); + + MlasGemmU8X8CopyPackBProcessUdot(D, BytesRow, BitFlipVector, ColumnSums); + + D += 32; + } + + // + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + + if (((CountK - 1) & 7) < 4) { + + vst1q_u8(&D[0], ZeroVector); + vst1q_u8(&D[16], ZeroVector); + + D += 32; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + uint8_t PaddedMatrixBData[32]; + uint32x4_t ColumnSums[2]; + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + + const uint8_t* bcopy0 = &b[ldb * 0]; + const uint8_t* bcopy1 = &b[ldb * 1]; + const uint8_t* bcopy2 = &b[ldb * 2]; + const uint8_t* bcopy3 = &b[ldb * 3]; + + if (k >= 4) { + + b += ldb * 4; + k -= 4; + + } else { + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[24]; + bcopy2 = (k > 2) ? bcopy2 : &PaddedMatrixBData[24]; + bcopy3 = &PaddedMatrixBData[24]; + + k = 0; + } + + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + } while (++padded < padded_end); + + BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); + + MlasGemmU8X8CopyPackBProcessUdot(D, BytesRow, BitFlipVector, ColumnSums); + + D += 32; + } + + // + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + + if (((CountK - 1) & 7) < 4) { + + vst1q_u8(&D[0], ZeroVector); + vst1q_u8(&D[16], ZeroVector); + + D += 32; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + } +} + +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + return MlasGemmU8X8KernelUdot(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchUdot = { + MlasGemmU8X8Operation, + MlasGemmU8X8PackedOperation, + MlasGemmU8X8CopyPackB, + MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK, + MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides.K, +}; #endif @@ -1939,188 +2484,177 @@ struct MLAS_GEMM_U8X8_KERNEL_DEFAULT static constexpr size_t PackedK = 4; static constexpr MLAS_GEMM_U8X8_STRIDES Strides{16, 128, 128}; static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{16, 128, 128}; +}; - MLAS_FORCEINLINE - static - bool - TryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned - ) - { - MLAS_UNREFERENCED_PARAMETER(A); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(CountK); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - return false; +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_DEFAULT::OffsetBType(ZeroPointB ^ 0x80); } - MLAS_FORCEINLINE - static - int32_t - FixupZeroPointB( - int32_t offb, - bool BIsSigned - ) - { - if (BIsSigned) { - offb = OffsetBType(offb ^ 0x80); + return ZeroPointB; +} + +template<> +void +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) +{ + const size_t AlignedCountK = + (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); + + // + // Process a single row of matrix A in a loop. + // + + while (CountM-- > 0) { + + int32_t RowSum = 0; + + for (size_t k = 0; k < CountK; k++) { + + uint8_t a0 = A[k]; + D[k] = a0; + + RowSum += a0; } - return offb; - } - - static - void - CopyPackA( - PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) - { - const size_t AlignedCountK = - (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); - - // - // Process a single row of matrix A in a loop. - // - - while (CountM-- > 0) { - - int32_t RowSum = 0; - - for (size_t k = 0; k < CountK; k++) { - - uint8_t a0 = A[k]; - D[k] = a0; - - RowSum += a0; - } - - for (size_t k = CountK; k < AlignedCountK; k++) { - D[k] = 0; - } - - *RowSumBuffer++ = RowSum; - - A += lda; - D += AlignedCountK; - } - } - - static - void - CopyPackB( - PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) - { - const size_t AlignedCountK = - (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); - const uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); - - // - // Process a single column of matrix B in a loop. - // - - while (CountN-- > 0) { - - const uint8_t* b = B; - int32_t ColumnSum = 0; - - // - // Transpose the data from matrix B to the packed buffer. - // - - for (size_t k = 0; k < CountK; k++) { - - uint8_t b0 = b[0] ^ BitFlipValue; - D[k] = b0; - - ColumnSum += b0; - - b += ldb; - } - - for (size_t k = CountK; k < AlignedCountK; k++) { - D[k] = 0; - } - - *ColumnSumBuffer++ = ColumnSum; - - B += 1; - D += AlignedCountK; - } - } - - static - size_t - GemmKernel( - const PackedAType* A, - const PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - int32_t DepthValue, - bool ZeroMode - ) - { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - // - // Process a single column of matrix B in a loop. - // - - while (CountN-- > 0) { - - int32_t Accumulator = RowSumBuffer[0] + ColumnSumBuffer[0] + DepthValue; - ColumnSumBuffer += 1; - - const PackedAType* a = A; - - for (size_t k = 0; k < PackedCountK; k++) { - - Accumulator += a[0] * B[0]; - Accumulator += a[1] * B[1]; - Accumulator += a[2] * B[2]; - Accumulator += a[3] * B[3]; - - a += 4; - B += 4; - } - - if (!ZeroMode) { - Accumulator += C[0]; - } - - C[0] = Accumulator; - C += 1; + for (size_t k = CountK; k < AlignedCountK; k++) { + D[k] = 0; } - return 1; + *RowSumBuffer++ = RowSum; + + A += lda; + D += AlignedCountK; } +} + +template<> +void +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + const size_t AlignedCountK = + (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); + const uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); + + // + // Process a single column of matrix B in a loop. + // + + while (CountN-- > 0) { + + const uint8_t* b = B; + int32_t ColumnSum = 0; + + // + // Transpose the data from matrix B to the packed buffer. + // + + for (size_t k = 0; k < CountK; k++) { + + uint8_t b0 = b[0] ^ BitFlipValue; + D[k] = b0; + + ColumnSum += b0; + + b += ldb; + } + + for (size_t k = CountK; k < AlignedCountK; k++) { + D[k] = 0; + } + + *ColumnSumBuffer++ = ColumnSum; + + B += 1; + D += AlignedCountK; + } +} + +template<> +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + // + // Process a single column of matrix B in a loop. + // + + while (CountN-- > 0) { + + int32_t Accumulator = *RowSumBuffer; + + if (ZeroPointB != nullptr) { + Accumulator *= *ZeroPointB++; + } + + Accumulator += *ColumnSumBuffer++; + + const auto* a = A; + + for (size_t k = 0; k < PackedCountK; k++) { + + Accumulator += a[0] * B[0]; + Accumulator += a[1] * B[1]; + Accumulator += a[2] * B[2]; + Accumulator += a[3] * B[3]; + + a += 4; + B += 4; + } + + if (!ZeroMode) { + Accumulator += C[0]; + } + + C[0] = Accumulator; + C += 1; + } + + return 1; +} + +const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchDefault = { + MlasGemmU8X8Operation, + nullptr, + nullptr, + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK, + 0, }; void @@ -2147,80 +2681,75 @@ Return Value: --*/ { - MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; + const auto* WorkBlock = (MLAS_GEMM_U8X8_WORK_BLOCK*)Context; + const auto* Parameters = WorkBlock->Parameters; - memcpy(&WorkBlock, Context, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK)); - - const int32_t ThreadIdM = ThreadId / WorkBlock.ThreadCountN; - const int32_t ThreadIdN = ThreadId % WorkBlock.ThreadCountN; + const int32_t ThreadIdM = ThreadId / WorkBlock->ThreadCountN; + const int32_t ThreadIdN = ThreadId % WorkBlock->ThreadCountN; // // Partition the operation along the M dimension. // - MlasPartitionWork(ThreadIdM, WorkBlock.ThreadCountM, WorkBlock.M, - &WorkBlock.RangeStartM, &WorkBlock.RangeCountM); + size_t RangeStartM; + size_t RangeCountM; + + const size_t M = Parameters->M; + + MlasPartitionWork(ThreadIdM, WorkBlock->ThreadCountM, M, &RangeStartM, &RangeCountM); // // Partition the operation along the N dimension. // - const size_t BlockedN = (WorkBlock.N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / + size_t RangeStartN; + size_t RangeCountN; + + const size_t N = Parameters->N; + + const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - MlasPartitionWork(ThreadIdN, WorkBlock.ThreadCountN, BlockedN, - &WorkBlock.RangeStartN, &WorkBlock.RangeCountN); + MlasPartitionWork(ThreadIdN, WorkBlock->ThreadCountN, BlockedN, + &RangeStartN, &RangeCountN); - WorkBlock.RangeStartN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - WorkBlock.RangeCountN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; + RangeStartN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - WorkBlock.RangeCountN = std::min(WorkBlock.N - WorkBlock.RangeStartN, - WorkBlock.RangeCountN); + RangeCountN = std::min(N - RangeStartN, RangeCountN); // // Dispatch the partitioned operation. // -#if defined(MLAS_TARGET_AMD64) - PMLAS_GEMM_U8X8_OPERATION GemmU8X8Operation; + const auto* GemmU8X8Dispatch = MlasGemmU8X8GetDispatch(Parameters->BIsSigned); + MLAS_GEMM_U8X8_OPERATION* GemmU8X8Operation; - if (WorkBlock.BIsSigned) { - GemmU8X8Operation = WorkBlock.BIsPacked ? - MlasPlatform.GemmU8S8PackedOperation : MlasPlatform.GemmU8S8Operation; + if (Parameters->BIsPacked) { + GemmU8X8Operation = GemmU8X8Dispatch->PackedOperation; } else { - GemmU8X8Operation = WorkBlock.BIsPacked ? - MlasPlatform.GemmU8U8PackedOperation : MlasPlatform.GemmU8U8Operation; + GemmU8X8Operation = GemmU8X8Dispatch->Operation; } - GemmU8X8Operation(&WorkBlock); -#elif defined(MLAS_SSE2_INTRINSICS) - MlasGemmU8X8Operation(&WorkBlock); -#elif defined(MLAS_NEON64_INTRINSICS) || (defined(MLAS_NEON32_INTRINSICS) && !defined(_MSC_VER)) - if (WorkBlock.BIsPacked) { - MlasGemmU8X8PackedOperation(&WorkBlock); - } else { - MlasGemmU8X8Operation(&WorkBlock); - } -#else - MlasGemmU8X8Operation(&WorkBlock); -#endif + GemmU8X8Operation(Parameters, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } void -MlasGemmU8X8Schedule( - MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock, +MLASCALL +MlasGemm( + const MLAS_GEMM_U8X8_PARAMETERS* Parameters, MLAS_THREADPOOL* ThreadPool ) /*++ Routine Description: - This routine schedules the quantized integer matrix/matrix multiply - operation (QGEMM) across one or more threads. + This routine implements the quantized integer matrix/matrix multiply + operation (QGEMM). Arguments: - WorkBlock - Supplies the structure containing the GEMM parameters. + Parameters - Supplies the structure containing the GEMM parameters. ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -2231,9 +2760,9 @@ Return Value: --*/ { - const size_t M = WorkBlock->M; - const size_t N = WorkBlock->N; - const size_t K = WorkBlock->K; + const size_t M = Parameters->M; + const size_t N = Parameters->N; + const size_t K = Parameters->K; // // Compute the number of target threads given the complexity of the SGEMM @@ -2263,6 +2792,10 @@ Return Value: // works okay for operations involving skinny matrices. // + MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; + + WorkBlock.Parameters = Parameters; + if (N > M) { const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / @@ -2272,8 +2805,8 @@ Return Value: TargetThreadCount = int32_t(BlockedN); } - WorkBlock->ThreadCountM = 1; - WorkBlock->ThreadCountN = TargetThreadCount; + WorkBlock.ThreadCountM = 1; + WorkBlock.ThreadCountN = TargetThreadCount; } else { @@ -2281,196 +2814,11 @@ Return Value: TargetThreadCount = int32_t(M); } - WorkBlock->ThreadCountM = TargetThreadCount; - WorkBlock->ThreadCountN = 1; + WorkBlock.ThreadCountM = TargetThreadCount; + WorkBlock.ThreadCountN = 1; } - MlasExecuteThreaded(MlasGemmU8X8Threaded, WorkBlock, TargetThreadCount, ThreadPool); -} - -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - int32_t* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool, - const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor - ) -/*++ - -Routine Description: - - This routine implements the quantized integer matrix/matrix multiply - operation (QGEMM). - -Arguments: - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - offa - Supplies the zero point offset of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - offb - Supplies the zero point offset of matrix B. - - BIsSigned - Supplies true if matrix B is signed data, else false if matrix - B is unsigned data. - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - - OutputProcessor - Post Processor on C. - -Return Value: - - None. - ---*/ -{ - MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK)); - - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = B; - WorkBlock.ldb = ldb; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.OutputProcessor = OutputProcessor; - WorkBlock.offa = offa; - WorkBlock.offb = offb; - WorkBlock.BIsSigned = BIsSigned; - - // - // Schedule the operation across a set of worker threads. - // - - MlasGemmU8X8Schedule(&WorkBlock, ThreadPool); -} - -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - -void -MLASCALL -MlasGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const void* PackedB, - uint8_t offb, - bool BIsSigned, - int32_t* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool, - const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor - ) -/*++ - -Routine Description: - - This routine implements the quantized integer matrix/matrix multiply - operation (QGEMM). - -Arguments: - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - offa - Supplies the zero point offset of matrix A. - - PackedB - Supplies the address of packed matrix B. - - offb - Supplies the zero point offset of matrix B. - - BIsSigned - Supplies true if matrix B is signed data, else false if matrix - B is unsigned data. - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - - OutputProcessor - Post Processor on C - -Return Value: - - None. - ---*/ -{ - MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK)); - - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = PackedB; - WorkBlock.C = (int32_t*)C; - WorkBlock.ldc = ldc; - WorkBlock.OutputProcessor = OutputProcessor, - WorkBlock.offa = offa; - WorkBlock.offb = offb; - WorkBlock.BIsPacked = true; - WorkBlock.BIsSigned = BIsSigned; - - // - // Schedule the operation across a set of worker threads. - // - - MlasGemmU8X8Schedule(&WorkBlock, ThreadPool); + MlasExecuteThreaded(MlasGemmU8X8Threaded, &WorkBlock, TargetThreadCount, ThreadPool); } size_t @@ -2498,34 +2846,23 @@ Arguments: Return Value: - Returns the number of bytes required to pack the matrix. + Returns the number of bytes required to pack the matrix, else zero if the + current implementation does not support packing. --*/ { // - // Retrieve the packing parameters based on the packed operation function. + // Retrieve the packing parameters. // - size_t PackedK; + const auto* GemmU8X8Dispatch = MlasGemmU8X8GetDispatch(BIsSigned); -#if defined(MLAS_TARGET_AMD64) - PMLAS_GEMM_U8X8_OPERATION GemmU8X8Operation = BIsSigned ? - MlasPlatform.GemmU8S8PackedOperation : MlasPlatform.GemmU8U8PackedOperation; + size_t PackedK = GemmU8X8Dispatch->PackedK; + size_t PackedStrideK = GemmU8X8Dispatch->PackedStrideK; - if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation) { - PackedK = MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK; - } else if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation) { - PackedK = MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK; - } else { + if (PackedStrideK == 0) { return 0; } -#elif defined(MLAS_NEON_INTRINSICS) - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - PackedK = MLAS_GEMM_U8X8_KERNEL_NEON::PackedK; -#else -#error Unknown architecture. -#endif // // Compute the number of bytes required to hold the packed buffer. @@ -2583,35 +2920,13 @@ Return Value: --*/ { // - // Retrieve the packing parameters based on the packed operation function. + // Retrieve the packing parameters. // - size_t PackedK; - size_t StrideK; + const auto* GemmU8X8Dispatch = MlasGemmU8X8GetDispatch(BIsSigned); -#if defined(MLAS_TARGET_AMD64) - PMLAS_GEMM_U8X8_OPERATION GemmU8X8Operation = BIsSigned ? - MlasPlatform.GemmU8S8PackedOperation : MlasPlatform.GemmU8U8PackedOperation; - - if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation) { - PackedK = MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK; - StrideK = MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K; - } else if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation) { - PackedK = MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK; - StrideK = MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K; - } else { -#ifdef MLAS_NO_EXCEPTION - abort(); -#else - throw std::runtime_error("packing unavailable"); -#endif - } -#elif defined(MLAS_NEON_INTRINSICS) - PackedK = MLAS_GEMM_U8X8_KERNEL_NEON::PackedK; - StrideK = MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K; -#else -#error Unknown architecture. -#endif + size_t PackedK = GemmU8X8Dispatch->PackedK; + size_t PackedStrideK = GemmU8X8Dispatch->PackedStrideK; // // Reserve and initialize storage for the column sum buffer to hold the sums @@ -2633,7 +2948,7 @@ Return Value: for (size_t k = 0; k < K; k += CountK) { - CountK = std::min(K - k, StrideK); + CountK = std::min(K - k, PackedStrideK); // // Step through each slice of matrix B along the N dimension. @@ -2650,17 +2965,7 @@ Return Value: CountN = std::min(N - n, BatchedN); -#if defined(MLAS_TARGET_AMD64) - if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation) { - MLAS_GEMM_U8S8_KERNEL_AVX2::CopyPackB(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); - } else { - MLAS_GEMM_U8U8_KERNEL_AVX2::CopyPackB(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); - } -#elif defined(MLAS_NEON_INTRINSICS) - MLAS_GEMM_U8X8_KERNEL_NEON::CopyPackB(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); -#else -#error Unknown architecture. -#endif + GemmU8X8Dispatch->CopyPackBRoutine(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); // // Accumulate this batch of the column sum buffer into the packed @@ -2678,5 +2983,3 @@ Return Value: B += ldb * CountK; } } - -#endif // MLAS_SUPPORTS_PACKED_GEMM_U8X8 diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S index 6245aa97df..5068d41cea 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S @@ -18,7 +18,6 @@ Abstract: --*/ #include "asmmacro.h" -#include "QgemmU8X8KernelAvx2Common.h" .intel_syntax noprefix @@ -27,7 +26,7 @@ Abstract: // .equ .LGemmU8S8CopyPackAFrame_PaddedMatrixAData, -72 - .equ .LGemmU8S8CopyPackAFrame_mask, -8 + .equ .LGemmU8S8CopyPackAFrame_Padding, -8 .equ .LGemmU8S8CopyPackAFrame_SavedR13, 0 .equ .LGemmU8S8CopyPackAFrame_SavedR12, 8 .equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16 @@ -76,8 +75,7 @@ Return Value: --*/ - .globl C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2) -C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2): + FUNCTION_ENTRY MlasGemmU8S8CopyPackAAvx2 push rbp push rbx @@ -101,9 +99,9 @@ C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2): and eax,15 # isolate unaligned count add eax,3 shr eax,2 # align unaligned count to quad count - mov DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp],eax - vpbroadcastd xmm10,DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp] - vpcmpgtd xmm10,xmm10,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] + neg rax + lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] // // Zero initialize the padded stack buffers. @@ -430,8 +428,7 @@ Return Value: --*/ - .globl C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2) -C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2): + FUNCTION_ENTRY MlasGemmU8S8CopyPackBAvx2 push rbp push rbx @@ -698,258 +695,4 @@ C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2): vmovdqu YMMWORD PTR [r9+32],ymm1 jmp .LCopyPackB.ExitRoutine -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single row of the - output block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 16). - - Vec2Reg - Supplies the low block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from matrix B. - - ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount - is 16). - - ymm2 - Supplies the broadcast value loaded from matrix A. - - ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. - ---*/ - - .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg - - vpmaddubsw ymm3,ymm2,ymm0 - vpmaddwd ymm3,ymm3,ymm12 -.if \ColumnCount\() == 16 - vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 - vpmaddubsw ymm2,ymm2,ymm1 - vpmaddwd ymm2,ymm2,ymm12 - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 -.else - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rcx - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm11 - Supplies the block accumulators. - - ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. - ---*/ - - .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - -.if \RowCount\() == 1 - vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()] - vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()] - vpmaddwd ymm3,ymm3,ymm12 -.if \ColumnCount\() == 16 - vpaddd ymm4,ymm4,ymm3 - vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32] - vpmaddwd ymm2,ymm2,ymm12 - vpaddd ymm5,ymm5,ymm2 -.else - vpaddd ymm5,ymm5,ymm3 -.endif -.else - vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11" -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm11 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.LComputeBlockBy1Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add rbx,4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,64 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8S8CopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8S8CopyPackBAvx2. - - C (rdx) - Supplies the address of matrix C. - - PackedCountK (rcx) - Supplies the number of packed columns from matrix A - and the number of packed rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the - zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrix A multplied by the zero point offset of matrix B. This value - is accumulated into every element of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .globl C_UNDERSCORE(MlasGemmU8S8KernelAvx2) -C_UNDERSCORE(MlasGemmU8S8KernelAvx2): - - push rbp - push rbx - push r12 - push r13 - - mov rax,.LGemmU8X8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - shl rcx,2 # convert to row length - movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] - mov r11,rdi - mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] - mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] - vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] - vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] - -// -// Process CountM rows of the matrices. -// - - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - mov r8d,4 # return 4 rows handled - ProcessCountM 4, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,r8d - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h deleted file mode 100644 index eb483c03fe..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h +++ /dev/null @@ -1,88 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8S8KernelAvx512Common.h - -Abstract: - - This module contains common kernel macros and structures for the quantized - integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and - AVX512VNNI kernels. - ---*/ - -#include "QgemmU8X8KernelAvx512Common.h" - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.if ((\RowCount\() & 1) == 0) - sub rbp,4*4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0*64, 0 - ComputeBlock \ColumnCount\(), \RowCount\(), 1*64, 4 - ComputeBlock \ColumnCount\(), \RowCount\(), 2*64, 8 - ComputeBlock \ColumnCount\(), \RowCount\(), 3*64, 12 - add rdi,4*4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add rbx,4*4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,4*64 # advance matrix B - sub rbp,4*4 # decrement quads remaining - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add rbp,4*4 # correct for over-subtract above - jz .LComputeBlockLoopExit\@ -.endif - -.LComputeBlockBy1Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add rbx,4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,64 # advance matrix B - sub rbp,4 # decrement quads remaining - jnz .LComputeBlockBy1Loop\@ - -.LComputeBlockLoopExit\@: - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Core.S deleted file mode 100644 index c1f3300425..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Core.S +++ /dev/null @@ -1,136 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8S8KernelAvx512Core.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX512 core instructions (BW/DQ/VL). - ---*/ - -#include "asmmacro.h" -#include "QgemmU8S8KernelAvx512Common.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single cell of the - output block. - -Arguments: - - AccumReg - Supplies the register to accumulate into. - - Mult1Reg - Supplies the first multiplication operand register. - - Mult2Reg - Supplies the second multiplication operand register. - -Implicit Arguments: - - zmm4 - Supplies a scratch register for intermediate results. - - zmm5 - Supplies a 512-bit with the broadcasted word value 0x0001. - ---*/ - - .macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg - - vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\() - vpmaddwd zmm4,zmm4,zmm5 - vpaddd \AccumReg\(),\AccumReg\(),zmm4 - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - -.if \ColumnCount\() >= 48 - vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] - vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()] - vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()] -.elseif \ColumnCount\() >= 32 - vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()] - vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()] -.else - vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()] -.endif - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2" - - .endm - -// -// Generate the GEMM kernel. -// - -GemmU8X8KernelAvx512Function U8S8, Avx512Core - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S deleted file mode 100644 index bf806d9265..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S +++ /dev/null @@ -1,106 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8S8KernelAvx512Vnni.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX512VNNI instructions. - ---*/ - -#include "asmmacro.h" -#include "QgemmU8S8KernelAvx512Common.h" -#include "AssembleAvx512Vnni.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - -.if \ColumnCount\() >= 48 - vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] - vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()] - vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()] -.elseif \ColumnCount\() >= 32 - vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()] - vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()] -.else - vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()] -.endif - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2" - - .endm - -// -// Generate the GEMM kernel. -// - -GemmU8X8KernelAvx512Function U8S8, Avx512Vnni - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvxVnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvxVnni.S deleted file mode 100644 index da819a865e..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvxVnni.S +++ /dev/null @@ -1,268 +0,0 @@ -/*++ - -Copyright (c) 2020 Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8S8KernelAvxVnni.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVXVNNI instructions. - ---*/ - -#include "asmmacro.h" -#include "QgemmU8X8KernelAvx2Common.h" -#include "AssembleAvxVnni.h" - - .intel_syntax noprefix - -/*++ -Macro Description: - - This macro generates code to multiply and accumulator a single row of the - output block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 16). - - Vec2Reg - Supplies the low block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from matrix B. - - ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount - is 16). - - ymm2 - Supplies the broadcast value loaded from matrix A. - ---*/ - - .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg - -.if \ColumnCount\() == 16 - VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 -.else - VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rcx - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - - vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.LComputeBlockBy1Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add rbx,4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,64 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8S8CopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8S8CopyPackBAvx2. - - C (rdx) - Supplies the address of matrix C. - - PackedCountK (rcx) - Supplies the number of packed columns from matrix A - and the number of packed rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the - zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrix A multplied by the zero point offset of matrix B. This value - is accumulated into every element of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .globl C_UNDERSCORE(MlasGemmU8S8KernelAvxVnni) -C_UNDERSCORE(MlasGemmU8S8KernelAvxVnni): - - push rbp - push rbx - push r12 - push r13 - - mov rax,.LGemmU8X8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - shl rcx,2 # convert to row length - movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] - mov r11,rdi - mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] - mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] - -// -// Process CountM rows of the matrices. -// - - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - mov r8d,6 # return 6 rows handled - ProcessCountM 6, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,r8d - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S index 1952578d72..2bdef12aeb 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S @@ -18,7 +18,6 @@ Abstract: --*/ #include "asmmacro.h" -#include "QgemmU8X8KernelAvx2Common.h" .intel_syntax noprefix @@ -27,7 +26,7 @@ Abstract: // .equ .LGemmU8U8CopyPackAFrame_PaddedMatrixAData, -72 - .equ .LGemmU8U8CopyPackAFrame_mask, -8 + .equ .LGemmU8U8CopyPackAFrame_Padding, -8 .equ .LGemmU8U8CopyPackAFrame_SavedR13, 0 .equ .LGemmU8U8CopyPackAFrame_SavedR12, 8 .equ .LGemmU8U8CopyPackAFrame_SavedRbx, 16 @@ -79,8 +78,7 @@ Return Value: --*/ - .globl C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2) -C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): + FUNCTION_ENTRY MlasGemmU8U8CopyPackAAvx2 push rbp push rbx @@ -102,9 +100,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): and eax,15 # isolate unaligned count inc eax shr eax,1 # align unaligned count to pair count - mov DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp],eax - vpbroadcastd ymm9,DWORD PTR .LGemmU8U8CopyPackAFrame_mask[rsp] - vpcmpgtd ymm9,ymm9,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] + neg rax + lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm9,YMMWORD PTR [rbx+rax*4] // // Zero initialize the padded stack buffers. @@ -394,8 +392,7 @@ Return Value: --*/ - .globl C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2) -C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): + FUNCTION_ENTRY MlasGemmU8U8CopyPackBAvx2 push rbp push rbx @@ -599,273 +596,4 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): vmovdqu YMMWORD PTR [r9+32],ymm1 jmp .LCopyPackB.ExitRoutine -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single row of the - output block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 16). - - Vec2Reg - Supplies the low block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from matrix B. - - ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount - is 16). - - ymm2 - Supplies the broadcast value loaded from matrix A. - ---*/ - - .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg - - vpmaddwd ymm3,ymm2,ymm0 -.if \ColumnCount\() == 16 - vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 - vpmaddwd ymm2,ymm2,ymm1 - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 -.else - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - - vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0) - sub rbp,2*4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy2Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4 - add rdi,2*4 # advance matrix A by 2 pairs -.if \RowCount\() > 3 - add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs -.endif - add rsi,2*32 # advance matrix B - sub rbp,2*4 - jae .LComputeBlockBy2Loop\@ - -.LProcessRemainingBlocks\@: - add rbp,2*4 # correct for over-subtract above - jz .LComputeBlockLoopExit\@ - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rsi,32 # advance matrix B -.else -.LComputeBlockBy1Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 pair -.if \RowCount\() > 3 - add rbx,4 # advance matrix A plus 3 rows by 1 pair -.endif - add rsi,32 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ -.endif - -.LComputeBlockLoopExit\@: - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8U8CopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8U8CopyPackBAvx2. - - C (rdx) - Supplies the address of matrix C. - - PackedCountK (rcx) - Supplies the number of packed columns from matrix A - and the number of packed rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the - zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrix A multplied by the zero point offset of matrix B. This value - is accumulated into every element of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .globl C_UNDERSCORE(MlasGemmU8U8KernelAvx2) -C_UNDERSCORE(MlasGemmU8U8KernelAvx2): - - push rbp - push rbx - push r12 - push r13 - - mov rax,.LGemmU8X8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - shl rcx,2 # convert to row length - movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] - mov r11,rdi - mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] - mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] - -// -// Process CountM rows of the matrices. -// - - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - mov r8d,6 # return 6 rows handled - ProcessCountM 6, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,r8d - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Core.S deleted file mode 100644 index 8fca218e3b..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Core.S +++ /dev/null @@ -1,178 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8U8KernelAvx512Core.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX512 core instructions (BW/DQ/VL). - ---*/ - -#include "asmmacro.h" -#include "QgemmU8X8KernelAvx512Common.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single cell of the - output block. - -Arguments: - - AccumReg - Supplies the register to accumulate into. - - Mult1Reg - Supplies the first multiplication operand register. - - Mult2Reg - Supplies the second multiplication operand register. - -Implicit Arguments: - - zmm4 - Supplies a scratch register for intermediate results. - ---*/ - - .macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg - - vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\() - vpaddd \AccumReg\(),\AccumReg\(),zmm4 - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - -.if \ColumnCount\() >= 48 - vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()] - vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()] - vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()] -.elseif \ColumnCount\() >= 32 - vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()] - vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()] -.else - vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()] -.endif - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2" - - .endm - -/*++ - - Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.LComputeBlockBy1Loop\@: - ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 pair -.if \RowCount\() > 3 - add rbx,4 # advance matrix A plus 3 rows by 1 pair -.endif - add rsi,32 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ - - .endm - -// -// Generate the GEMM kernel. -// - -GemmU8X8KernelAvx512Function U8U8, Avx512Core - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S new file mode 100644 index 0000000000..b9caba221c --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S @@ -0,0 +1,868 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelAvx2.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX2 and AVX VNNI instructions. + +--*/ + +#include "asmmacro.h" +#include "AssembleAvxVnni.h" + + .intel_syntax noprefix + +// +// Stack frame layout for the U8X8 kernel. +// + + .equ .LGemmU8X8KernelFrame_type, -8 + .equ .LGemmU8X8KernelFrame_SavedR13, 0 + .equ .LGemmU8X8KernelFrame_SavedR12, 8 + .equ .LGemmU8X8KernelFrame_SavedRbx, 16 + .equ .LGemmU8X8KernelFrame_SavedRbp, 24 + .equ .LGemmU8X8KernelFrame_ReturnAddress, 32 + .equ .LGemmU8X8KernelFrame_ldc, 40 + .equ .LGemmU8X8KernelFrame_RowSumBuffer, 48 + .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 56 + .equ .LGemmU8X8KernelFrame_ZeroPointB, 64 + .equ .LGemmU8X8KernelFrame_ZeroMode, 72 + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulator a single row of the + output block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + Vec1Reg - Supplies the high block accumulator register (when ColumnCount + is 16). + + Vec2Reg - Supplies the low block accumulator register. + +Implicit Arguments: + + ymm0 - Supplies the first vector loaded from matrix B. + + ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount + is 16). + + ymm2 - Supplies the broadcast value loaded from matrix A. + + ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. + +--*/ + + .macro MultiplyAccumulateRowU8S8Avx2 ColumnCount, Vec1Reg, Vec2Reg + + vpmaddubsw ymm3,ymm2,ymm0 + vpmaddwd ymm3,ymm3,ymm12 +.if \ColumnCount\() == 16 + vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 + vpmaddubsw ymm2,ymm2,ymm1 + vpmaddwd ymm2,ymm2,ymm12 + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 +.else + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm11 - Supplies the block accumulators. + + ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. + +--*/ + + .macro ComputeBlockU8S8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset + +.if \RowCount\() == 1 + vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()] + vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()] + vpmaddwd ymm3,ymm3,ymm12 +.if \ColumnCount\() == 16 + vpaddd ymm4,ymm4,ymm3 + vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32] + vpmaddwd ymm2,ymm2,ymm12 + vpaddd ymm5,ymm5,ymm2 +.else + vpaddd ymm5,ymm5,ymm3 +.endif +.else + vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] + EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm4, ymm5" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm6, ymm7" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm8, ymm9" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm10, ymm11" +.endif + + .endm + +/*++ +Macro Description: + + This macro generates code to multiply and accumulator a single row of the + output block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + Vec1Reg - Supplies the high block accumulator register (when ColumnCount + is 16). + + Vec2Reg - Supplies the low block accumulator register. + +Implicit Arguments: + + ymm0 - Supplies the first vector loaded from matrix B. + + ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount + is 16). + + ymm2 - Supplies the broadcast value loaded from matrix A. + +--*/ + + .macro MultiplyAccumulateRowU8S8AvxVnni ColumnCount, Vec1Reg, Vec2Reg + +.if \ColumnCount\() == 16 + VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 + VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 +.else + VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockU8S8AvxVnni ColumnCount, RowCount, VectorOffset, BroadcastOffset + + vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] + EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm4, ymm5" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm6, ymm7" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm8, ymm9" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm10, ymm11" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm12, ymm13" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8S8AvxVnni \ColumnCount\(), ymm14, ymm15" + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple times + and advancing the matrix A and matrix B data pointers. + +Arguments: + + Isa - Supplies the instruction set architecture string. + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm11 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount + + mov rbp,rcx # reload row length remaining + +.LComputeBlockBy1Loop\@: + ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 quad +.if \RowCount\() > 3 + add r8,4 # advance matrix A plus 3 rows by 1 quad +.endif + add rsi,64 # advance matrix B + sub rbp,4 + jnz .LComputeBlockBy1Loop\@ + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulator a single row of the + output block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + Vec1Reg - Supplies the high block accumulator register (when ColumnCount + is 16). + + Vec2Reg - Supplies the low block accumulator register. + +Implicit Arguments: + + ymm0 - Supplies the first vector loaded from matrix B. + + ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount + is 16). + + ymm2 - Supplies the broadcast value loaded from matrix A. + +--*/ + + .macro MultiplyAccumulateRowU8U8Avx2 ColumnCount, Vec1Reg, Vec2Reg + + vpmaddwd ymm3,ymm2,ymm0 +.if \ColumnCount\() == 16 + vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 + vpmaddwd ymm2,ymm2,ymm1 + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 +.else + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockU8U8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset + + vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()] + EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]" + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm4, ymm5" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm6, ymm7" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm8, ymm9" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm10, ymm11" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm12, ymm13" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm14, ymm15" + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple times + and advancing the matrix A and matrix B data pointers. + +Arguments: + + Isa - Supplies the instruction set architecture string. + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount + + mov rbp,rcx # reload row length remaining + +.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0) + sub rbp,2*4 + jb .LProcessRemainingBlocks\@ + +.LComputeBlockBy2Loop\@: + ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 + ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 32, 4 + add rdi,2*4 # advance matrix A by 2 pairs +.if \RowCount\() > 3 + add r8,2*4 # advance matrix A plus 3 rows by 2 pairs +.endif + add rsi,2*32 # advance matrix B + sub rbp,2*4 + jae .LComputeBlockBy2Loop\@ + +.LProcessRemainingBlocks\@: + add rbp,2*4 # correct for over-subtract above + jz .LComputeBlockLoopExit\@ + ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 + add rsi,32 # advance matrix B +.else +.LComputeBlockBy1Loop\@: + ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 pair +.if \RowCount\() > 3 + add r8,4 # advance matrix A plus 3 rows by 1 pair +.endif + add rsi,32 # advance matrix B + sub rbp,4 + jnz .LComputeBlockBy1Loop\@ +.endif + +.LComputeBlockLoopExit\@: + + .endm + +/*++ + +Macro Description: + + This macro generates code to produce an output block for a set of columns + and rows. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r11 - Supplies the address of the row sum buffer. + + r12 - Supplies the address of the column sum buffer. + + ymm4-ymm15 - Supplies the block accumulators. + +--*/ + + .macro ProduceOutputBlock ColumnCount, RowCount + +// +// Initialize the accumulators with the row and column sums. +// + + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r11]" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r11+4]" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r11+8]" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r11+12]" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r11+16]" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r11+20]" +.if \ColumnCount\() == 16 + vmovdqu ymm0,YMMWORD PTR [r12] + vmovdqu ymm1,YMMWORD PTR [r12+32] + add r12,16*4 # advance ColumnSumBuffer by 16 columns +.else + vmovdqu ymm1,YMMWORD PTR [r12] +.endif + test r13,r13 # per column zero points? + jz .LSkipScaleByZeroPointB\@ +.if \ColumnCount\() == 16 + vmovdqu ymm2,YMMWORD PTR [r13] + vmovdqu ymm3,YMMWORD PTR [r13+32] + add r13,16*4 # advance ZeroPointB by 16 columns +.else + vmovdqu ymm3,YMMWORD PTR [r13] +.endif + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld ymm4,ymm5,ymm2" + EmitIfCountGE \RowCount\(), 1, "vpmulld ymm5,ymm5,ymm3" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm0,ymm4" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm1,ymm5" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld ymm6,ymm7,ymm2" + EmitIfCountGE \RowCount\(), 2, "vpmulld ymm7,ymm7,ymm3" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm0,ymm6" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm1,ymm7" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld ymm8,ymm9,ymm2" + EmitIfCountGE \RowCount\(), 3, "vpmulld ymm9,ymm9,ymm3" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm0,ymm8" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm1,ymm9" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld ymm10,ymm11,ymm2" + EmitIfCountGE \RowCount\(), 4, "vpmulld ymm11,ymm11,ymm3" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm0,ymm10" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm1,ymm11" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld ymm12,ymm13,ymm2" + EmitIfCountGE \RowCount\(), 5, "vpmulld ymm13,ymm13,ymm3" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm0,ymm12" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm1,ymm13" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld ymm14,ymm15,ymm2" + EmitIfCountGE \RowCount\(), 6, "vpmulld ymm15,ymm15,ymm3" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm0,ymm14" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm1,ymm15" + jmp .LAccumulatorsInitialized\@ + +.LSkipScaleByZeroPointB\@: + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1" + +.LAccumulatorsInitialized\@: + +// +// Iterate over the length of a matrix A row to produce the output accumulators. +// + +.if \RowCount\() > 3 + lea r8,[rcx*2+rcx] + add r8,rdi # compute matrix A plus 3 rows +.endif + cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0 + jg .LProduceWithU8U8Avx2\@ +.if \RowCount\() <= 4 + jl .LProduceWithU8S8AvxVnni\@ + ComputeBlockLoopU8S8 Avx2, \ColumnCount\(), \RowCount\() + jmp .LExitProduceOutputBlock\@ +.endif + +.LProduceWithU8S8AvxVnni\@: + ComputeBlockLoopU8S8 AvxVnni, \ColumnCount\(), \RowCount\() + jmp .LExitProduceOutputBlock\@ + +.LProduceWithU8U8Avx2\@: + ComputeBlockLoopU8U8 Avx2, \ColumnCount\(), \RowCount\() + +.LExitProduceOutputBlock\@: +.if \RowCount\() > 3 + lea r8,[rax*2+rax] + add r8,rdx # compute matrix C plus 3 rows +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address of matrix A. + + rsi - Supplies the address of matrix B. + + rdx - Supplies the address of matrix C. + + rbx - Supplies the address of matrix A. + + r9 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + rcx - Supplies the length in bytes of a row from matrix A. + + r10b - Supplies the zero mode flag. + + r11 - Supplies the address of the row sum buffer. + + r12 - Supplies the address of the column sum buffer. + +--*/ + + .macro ProcessCountM RowCount + + cmp r9,8 + jbe .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop16xN\@: + ProduceOutputBlock 16, \RowCount\() + sub r9,16 + jb .LOutputMasked16xNBlock\@ + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput16xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8+32]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax+32]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2+32]" + +.LSkipAccumulateOutput16xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8+32],ymm11" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax+32],ymm13" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2+32],ymm15" + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,rbx # reload matrix A + cmp r9,8 + ja .LProcessNextColumnLoop16xN\@ + test r9,r9 + jnz .LProcessRemainingCountN\@ + +.LExitProcessCountM\@: + mov eax,\RowCount\() + jmp .LExitKernel + +.LProcessRemainingCountN\@: + ProduceOutputBlock 8, \RowCount\() + cmp r9,8 + jb .LOutputMasked8xNBlock\@ + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput8xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2]" + +.LSkipAccumulateOutput8xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm11" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm13" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm15" + jmp .LExitProcessCountM\@ + +.LOutputMasked16xNBlock\@: + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutputMasked16xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]" + +.LSkipAccumulateOutputMasked16xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14" + add rdx,8*4 # advance matrix C by 8 columns +.if \RowCount\() > 3 + add r8,8*4 # advance matrix C plus 3 rows by 8 columns +.endif + add r9,8 # correct for over-subtract above + +.LOutputMasked8xNBlock\@: + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutputMasked8xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [r8+rax*2]" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14" + +.LSkipAccumulateOutputMasked8xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5" + EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7" + EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9" + EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [r8],ymm0,ymm11" + EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm13" + EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm15" + jmp .LExitProcessCountM\@ + + .endm + +// +// Reduce code size for the various types of kernels by sharing the outer logic +// and switching on the selector codes (using sign bit to discriminate). +// + + FUNCTION_ENTRY MlasGemmU8S8KernelAvxVnni + + mov eax,-1 + jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2) + + FUNCTION_ENTRY MlasGemmU8U8KernelAvx2 + + mov eax,1 + jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2) + + FUNCTION_ENTRY MlasGemmU8S8KernelAvx2 + + xor eax,eax + jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx2) + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (rdi) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8X8CopyPackAAvx2. + + B (rsi) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8X8CopyPackBAvx2. + + C (rdx) - Supplies the address of matrix C. + + PackedCountK (rcx) - Supplies the number of packed columns from matrix A + and the number of packed rows from matrix B to iterate over. + + CountM (r8) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (r9) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc - Supplies the first dimension of matrix C. + + RowSumBuffer - Supplies the sum of each row from matrix A. These values have + been pre-scaled by the zero point offset of matrix B if the offset is + per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + FUNCTION_ENTRY MlasGemmU8X8KernelAvx2 + + push rbp + push rbx + push r12 + push r13 + + mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax + mov rbx,rdi + mov rax,.LGemmU8X8KernelFrame_ldc[rsp] + shl rax,2 # convert ldc to bytes + shl rcx,2 # convert to row length + movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] + mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] + mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] + mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp] + vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] + vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] + cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0 + je .LCheckCountM4OrMore # U8S8 AVX2 kernel requires extra registers + +// +// Process CountM rows of the matrices. +// + +.LCheckCountM6OrMore: + cmp r8,5 + ja .LProcessCountM6 + je .LProcessCountM5 + +.LCheckCountM4OrMore: + cmp r8,3 + ja .LProcessCountM4 + je .LProcessCountM3 + cmp r8,1 + je .LProcessCountM1 + +.LProcessCountM2: + ProcessCountM 2 + +.LProcessCountM4: + ProcessCountM 4 + +.LProcessCountM6: + ProcessCountM 6 + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + vzeroupper + + pop r13 + pop r12 + pop rbx + pop rbp + ret + +.LProcessCountM1: + ProcessCountM 1 + +.LProcessCountM3: + ProcessCountM 3 + +.LProcessCountM5: + ProcessCountM 5 + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h deleted file mode 100644 index 8a2bf10db8..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h +++ /dev/null @@ -1,273 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelAvx2Common.h - -Abstract: - - This module contains common kernel macros and structures for the quantized - integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels. - ---*/ - -// -// Stack frame layout for the U8S8 and U8U8 kernels. -// - - .equ .LGemmU8X8KernelFrame_mask, -8 - .equ .LGemmU8X8KernelFrame_SavedR13, 0 - .equ .LGemmU8X8KernelFrame_SavedR12, 8 - .equ .LGemmU8X8KernelFrame_SavedRbx, 16 - .equ .LGemmU8X8KernelFrame_SavedRbp, 24 - .equ .LGemmU8X8KernelFrame_ReturnAddress, 32 - .equ .LGemmU8X8KernelFrame_ldc, 40 - .equ .LGemmU8X8KernelFrame_RowSumBuffer, 48 - .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 56 - .equ .LGemmU8X8KernelFrame_DepthValue, 64 - .equ .LGemmU8X8KernelFrame_ZeroMode, 72 - -/*++ - -Macro Description: - - This macro generates code to produce an output block for a set of columns - and rows. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r12 - Supplies the address of the row sum buffer. - - r13 - Supplies the address of the column sum buffer. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ProduceOutputBlock ColumnCount, RowCount - -// -// Initialize the accumulators with the sum of the global depth value constant, -// the column sums, and the row sums. -// - - vpbroadcastd ymm1,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp] -.if \ColumnCount\() == 16 - vpaddd ymm0,ymm1,YMMWORD PTR [r13] - vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] - add r13,16*4 # advance ColumnSumBuffer by 16 columns -.else - vpaddd ymm1,ymm1,YMMWORD PTR [r13] -.endif - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1" - -// -// Iterate over the length of a matrix A row to produce the output accumulators. -// - -.if \RowCount\() > 3 - lea rbx,[rcx*2+rcx] - add rbx,rdi # compute matrix A plus 3 rows -.endif - ComputeBlockLoop \ColumnCount\(), \RowCount\() -.if \RowCount\() > 3 - lea rbx,[rax*2+rax] - add rbx,rdx # compute matrix C plus 3 rows -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - rdx - Supplies the address of matrix C. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rcx - Supplies the length in bytes of a row from matrix A. - - r10b - Supplies the zero mode flag. - - r12 - Supplies the address of the row sum buffer. - - r13 - Supplies the address of the column sum buffer. - ---*/ - - .macro ProcessCountM RowCount, Fallthrough - - cmp r9,8 - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop16xN\@: - ProduceOutputBlock 16, \RowCount\() - sub r9,16 - jb .LOutputMasked16xNBlock\@ - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput16xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]" - -.LSkipAccumulateOutput16xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15" - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,r11 # reload matrix A - cmp r9,8 - ja .LProcessNextColumnLoop16xN\@ - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN\@: - ProduceOutputBlock 8, \RowCount\() - cmp r9,8 - jb .LOutputMasked8xNBlock\@ - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput8xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput8xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15" - jmp .LExitKernel - -.LOutputMasked16xNBlock\@: - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutputMasked16xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutputMasked16xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" - add rdx,8*4 # advance matrix C by 8 columns -.if \RowCount\() > 3 - add rbx,8*4 # advance matrix C plus 3 rows by 8 columns -.endif - add r9,8 # correct for over-subtract above - -.LOutputMasked8xNBlock\@: - mov DWORD PTR .LGemmU8X8KernelFrame_mask[rsp],r9d - vpbroadcastd ymm0,DWORD PTR .LGemmU8X8KernelFrame_mask[rsp] - vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutputMasked8xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14" - -.LSkipAccumulateOutputMasked8xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5" - EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7" - EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9" - EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11" - EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13" - EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15" -.ifb \Fallthrough\() - jmp .LExitKernel -.endif - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h deleted file mode 100644 index 7db55ba661..0000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h +++ /dev/null @@ -1,403 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelAvx512Common.h - -Abstract: - - This module contains common kernel macros and structures for the quantized - integer matrix/matrix multiply operation (QGEMM) for the AVX512 core and - AVX512VNNI kernels. - ---*/ - -// -// Stack frame layout for the U8S8 and U8U8 kernels. -// - - .equ .LGemmU8X8KernelFrame_SavedR14, 0 - .equ .LGemmU8X8KernelFrame_SavedR13, 8 - .equ .LGemmU8X8KernelFrame_SavedR12, 16 - .equ .LGemmU8X8KernelFrame_SavedRbx, 24 - .equ .LGemmU8X8KernelFrame_SavedRbp, 32 - .equ .LGemmU8X8KernelFrame_ReturnAddress, 40 - .equ .LGemmU8X8KernelFrame_ldc, 48 - .equ .LGemmU8X8KernelFrame_RowSumBuffer, 56 - .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64 - .equ .LGemmU8X8KernelFrame_DepthValue, 72 - .equ .LGemmU8X8KernelFrame_ZeroMode, 80 - -/*++ - -Macro Description: - - This macro generates code to produce an output block for a set of columns - and rows. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r12 - Supplies the address of the row sum buffer. - - r13 - Supplies the address of the column sum buffer. - ---*/ - - .macro ProduceOutputBlock ColumnCount, RowCount - -// -// Initialize the accumulators with the sum of the global depth value constant, -// the column sums, and the row sums. -// - - vpbroadcastd zmm3,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp] -.if \ColumnCount\() >= 32 -.if \ColumnCount\() >= 48 - vpaddd zmm2,zmm3,ZMMWORD PTR [r13] - vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64] - vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128] -.else - vpaddd zmm1,zmm3,ZMMWORD PTR [r13] - vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64] -.endif - add_immed r13,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns -.else - vpaddd zmm0,zmm3,ZMMWORD PTR [r13] -.endif - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r12]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r12]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r12]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r12+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r12+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r12+4]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r12+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r12+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r12+8]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r12+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r12+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r12+12]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r12+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r12+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r12+16]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r12+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r12+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r12+20]{1to16}" - -// -// Iterate over the length of a matrix A row to produce the output accumulators. -// - -.if \RowCount\() > 3 - lea rbx,[rcx*2+rcx] - add rbx,rdi # compute matrix A plus 3 rows -.endif - ComputeBlockLoop \ColumnCount\(), \RowCount\() -.if \RowCount\() > 3 - lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows - add rbx,rax -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - rdx - Supplies the address of matrix C. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rcx - Supplies the length in bytes of a row from matrix A. - - r10b - Supplies the zero mode flag. - - r12 - Supplies the address of the row sum buffer. - - r13 - Supplies the address of the column sum buffer. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - ---*/ - - .macro ProcessCountM RowCount - - cmp r9,32 - ja .LProcessNextColumnLoop48xN\@ - cmp r9,16 - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop32xN\@: - ProduceOutputBlock 32, \RowCount\() - add rsi,r14 # advance matrix B by packed block stride - -.LOutput32xNBlock\@: - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput32xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput32xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm23" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25" - add rdx,16*4 # advance matrix C by 16 columns -.if \RowCount\() > 3 - add rbx,16*4 # advance matrix C plus 3 rows by 16 columns -.endif - sub r9,16 - -.LOutput16xNBlock\@: - sub r9,16 - jae .LOutput16xNBlockWithMask\@ - lea rcx,[r9+16] # correct for over-subtract above - mov ebp,1 - shl ebp,cl - dec ebp - kmovw k1,ebp # update mask for remaining columns - xor r9,r9 # no more columns remaining - -.LOutput16xNBlockWithMask\@: - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput16xNBlockWithMask\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput16xNBlockWithMask\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19" - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,r11 # reload matrix A - cmp r9,32 - ja .LProcessNextColumnLoop48xN\@ - cmp r9,16 - ja .LProcessNextColumnLoop32xN\@ - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN\@: - ProduceOutputBlock 16, \RowCount\() - jmp .LOutput16xNBlock\@ - -.LProcessNextColumnLoop48xN\@: - ProduceOutputBlock 48, \RowCount\() - lea rsi,[rsi+r14*2] # advance matrix B by packed block stride - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput48xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput48xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm29" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31" - add rdx,16*4 # advance matrix C by 16 columns -.if \RowCount\() > 3 - add rbx,16*4 # advance matrix C plus 3 rows by 16 columns -.endif - sub r9,16 - jmp .LOutput32xNBlock\@ - - .endm - -/*++ - -Macro Description: - - This macro generates the common AVX512 code for the inner kernel to compute - matrix multiplication. - -Arguments: - - Type - Supplies the kernel type string for function tags. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro GemmU8X8KernelAvx512Function Type, Isa - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBAvx2. - - C (rdx) - Supplies the address of matrix C. - - PackedCountK (rcx) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the - zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrixA multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .globl C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()) -C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()): - - push rbp - push rbx - push r12 - push r13 - push r14 - - mov rax,.LGemmU8X8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - shl rcx,2 # convert to row length - movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] - mov r11,rdi - mov r12,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] - mov r13,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] - mov ebp,-1 - kmovw k1,ebp # update mask to write all columns -.ifeqs "\Type\()", "U8S8" -.ifeqs "\Isa\()", "Avx512Core" - neg ebp - vpbroadcastw zmm5,ebp # generate 512-bit word vector [0x0001] -.endif - mov r14,rcx - shl r14,4 # compute matrix B packed stride -.else - lea r14,[rcx*8] # compute matrix B packed stride -.endif - -// -// Process CountM rows of the matrices. -// - - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - mov r8d,6 # return 6 rows handled - ProcessCountM 6 - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,r8d - vzeroupper - - pop r14 - pop r13 - pop r12 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S new file mode 100644 index 0000000000..156091221e --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S @@ -0,0 +1,709 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelAvx512Core.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions. + +--*/ + +#include "asmmacro.h" +#include "AssembleAvx512Vnni.h" + + .intel_syntax noprefix + +// +// Stack frame layout for the U8X8 kernel. +// + + .equ .LGemmU8X8KernelFrame_type, -8 + .equ .LGemmU8X8KernelFrame_SavedR14, 0 + .equ .LGemmU8X8KernelFrame_SavedR13, 8 + .equ .LGemmU8X8KernelFrame_SavedR12, 16 + .equ .LGemmU8X8KernelFrame_SavedRbx, 24 + .equ .LGemmU8X8KernelFrame_SavedRbp, 32 + .equ .LGemmU8X8KernelFrame_ReturnAddress, 40 + .equ .LGemmU8X8KernelFrame_ldc, 48 + .equ .LGemmU8X8KernelFrame_RowSumBuffer, 56 + .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64 + .equ .LGemmU8X8KernelFrame_ZeroPointB, 72 + .equ .LGemmU8X8KernelFrame_ZeroMode, 80 + + .text + +/*++ + +Macro Description: + + This macro generates code to load packed data from matrix B. + +Arguments: + + VecReg - Supplies the register to load the data into. + + AddressOperand - Supplies the address operand. + +--*/ + + .macro LoadPackedMatrixBU8S8 VecReg, AddressOperand + + vmovdqu32 \VecReg\(),ZMMWORD PTR \AddressOperand\() + + .endm + + .macro LoadPackedMatrixBU8U8 VecReg, AddressOperand + + vpmovzxbw \VecReg\(),YMMWORD PTR \AddressOperand\() + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulator a single cell of the + output block. + +Arguments: + + AccumReg - Supplies the register to accumulate into. + + Mult1Reg - Supplies the first multiplication operand register. + + Mult2Reg - Supplies the second multiplication operand register. + +Implicit Arguments: + + zmm4 - Supplies a scratch register for intermediate results. + + zmm13 - Supplies a 512-bit with the broadcasted word value 0x0001. + +--*/ + + .macro MultiplyAccumulateCellU8S8Avx512Core AccumReg, Mult1Reg, Mult2Reg + + vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\() + vpmaddwd zmm4,zmm4,zmm13 + vpaddd \AccumReg\(),\AccumReg\(),zmm4 + + .endm + + .macro MultiplyAccumulateCellU8S8Avx512Vnni AccumReg, Mult1Reg, Mult2Reg + + VpdpbusdsZmmZmmZmm \AccumReg\(),\Mult1Reg\(),\Mult2Reg\() + + .endm + + .macro MultiplyAccumulateCellU8U8Avx512Core AccumReg, Mult1Reg, Mult2Reg + + vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\() + vpaddd \AccumReg\(),\AccumReg\(),zmm4 + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r14 - Supplies the stride in bytes of between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlock Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset + +.if \ColumnCount\() >= 48 + LoadPackedMatrixB\Type\() zmm0,"[rsi+\VectorOffset\()]" + LoadPackedMatrixB\Type\() zmm1,"[rsi+r14+\VectorOffset\()]" + LoadPackedMatrixB\Type\() zmm2,"[rsi+r14*2+\VectorOffset\()]" +.elseif \ColumnCount\() >= 32 + LoadPackedMatrixB\Type\() zmm1,"[rsi+\VectorOffset\()]" + LoadPackedMatrixB\Type\() zmm2,"[rsi+r14+\VectorOffset\()]" +.else + LoadPackedMatrixB\Type\() zmm2,"[rsi+\VectorOffset\()]" +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm26,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm20,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm14,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm27,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm21,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm15,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm28,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm22,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm16,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [r8+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm29,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm23,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm17,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [r8+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm30,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm24,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm18,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm31,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm25,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm19,zmm3,zmm2" + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple times + and advancing the matrix A and matrix B data pointers. + +Arguments: + + Isa - Supplies the instruction set architecture string. + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rdi - Supplies the address into the matrix A data. + + r8 - Supplies the address into the matrix A data plus 3 rows. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r14 - Supplies the stride in bytes of between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount + + mov rbp,rcx # reload row length remaining + +.if ((\RowCount\() & 1) == 0) + sub rbp,4*4 + jb .LProcessRemainingBlocks\@ + +.LComputeBlockBy4Loop\@: + ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0*64, 0 + ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 1*64, 4 + ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 2*64, 8 + ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 3*64, 12 + add rdi,4*4 # advance matrix A by 1 quad +.if \RowCount\() > 3 + add r8,4*4 # advance matrix A plus 3 rows by 1 quad +.endif + add rsi,4*64 # advance matrix B + sub rbp,4*4 # decrement quads remaining + jae .LComputeBlockBy4Loop\@ + +.LProcessRemainingBlocks\@: + add rbp,4*4 # correct for over-subtract above + jz .LComputeBlockLoopExit\@ +.endif + +.LComputeBlockBy1Loop\@: + ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 quad +.if \RowCount\() > 3 + add r8,4 # advance matrix A plus 3 rows by 1 quad +.endif + add rsi,64 # advance matrix B + sub rbp,4 # decrement quads remaining + jnz .LComputeBlockBy1Loop\@ + +.LComputeBlockLoopExit\@: + + .endm + + .macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount + + mov rbp,rcx # reload row length remaining + +.LComputeBlockBy1Loop\@: + ComputeBlock U8U8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 pair +.if \RowCount\() > 3 + add r8,4 # advance matrix A plus 3 rows by 1 pair +.endif + add rsi,32 # advance matrix B + sub rbp,4 + jnz .LComputeBlockBy1Loop\@ + + .endm + +/*++ + +Macro Description: + + This macro generates code to produce an output block for a set of columns + and rows. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r11 - Supplies the address of the row sum buffer. + + r12 - Supplies the address of the column sum buffer. + +--*/ + + .macro ProduceOutputBlock ColumnCount, RowCount + +// +// Initialize the accumulators with the row and column sums. +// + +.if \ColumnCount\() >= 32 +.if \ColumnCount\() >= 48 + vmovdqu32 zmm2,ZMMWORD PTR [r12] + vmovdqu32 zmm1,ZMMWORD PTR [r12+64] + vmovdqu32 zmm0,ZMMWORD PTR [r12+128] +.else + vmovdqu32 zmm1,ZMMWORD PTR [r12] + vmovdqu32 zmm0,ZMMWORD PTR [r12+64] +.endif + add_immed r12,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns +.else + vmovdqu32 zmm0,ZMMWORD PTR [r12] +.endif + test r13,r13 # per column zero points? + jz .LSkipScaleByZeroPointB\@ +.if \ColumnCount\() >= 32 +.if \ColumnCount\() >= 48 + vmovdqu32 zmm5,ZMMWORD PTR [r13] + vmovdqu32 zmm4,ZMMWORD PTR [r13+64] + vmovdqu32 zmm3,ZMMWORD PTR [r13+128] +.else + vmovdqu32 zmm4,ZMMWORD PTR [r13] + vmovdqu32 zmm3,ZMMWORD PTR [r13+64] +.endif + add_immed r13,\ColumnCount\()*4 # advance ZeroPointB by N columns +.else + vmovdqu32 zmm3,ZMMWORD PTR [r13] +.endif + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld zmm14,zmm3,DWORD PTR [r11]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld zmm20,zmm4,DWORD PTR [r11]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpmulld zmm26,zmm5,DWORD PTR [r11]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,zmm14" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,zmm20" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,zmm26" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld zmm15,zmm3,DWORD PTR [r11+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpmulld zmm21,zmm4,DWORD PTR [r11+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpmulld zmm27,zmm5,DWORD PTR [r11+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,zmm15" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,zmm21" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,zmm27" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld zmm16,zmm3,DWORD PTR [r11+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpmulld zmm22,zmm4,DWORD PTR [r11+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpmulld zmm28,zmm5,DWORD PTR [r11+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,zmm16" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,zmm22" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,zmm28" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld zmm17,zmm3,DWORD PTR [r11+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpmulld zmm23,zmm4,DWORD PTR [r11+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpmulld zmm29,zmm5,DWORD PTR [r11+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,zmm17" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,zmm23" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,zmm29" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld zmm18,zmm3,DWORD PTR [r11+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpmulld zmm24,zmm4,DWORD PTR [r11+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpmulld zmm30,zmm5,DWORD PTR [r11+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,zmm18" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,zmm24" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,zmm30" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld zmm19,zmm3,DWORD PTR [r11+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpmulld zmm25,zmm4,DWORD PTR [r11+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpmulld zmm31,zmm5,DWORD PTR [r11+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,zmm19" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,zmm25" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,zmm31" + jmp .LAccumulatorsInitialized\@ + +.LSkipScaleByZeroPointB\@: + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r11]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r11]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r11]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r11+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r11+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r11+4]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r11+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r11+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r11+8]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r11+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r11+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r11+12]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r11+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r11+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r11+16]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r11+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r11+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r11+20]{1to16}" + +.LAccumulatorsInitialized\@: + +// +// Iterate over the length of a matrix A row to produce the output accumulators. +// + +.if \RowCount\() > 3 + lea r8,[rcx*2+rcx] + add r8,rdi # compute matrix A plus 3 rows +.endif + cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0 + je .LProduceWithU8S8Avx512Core\@ + jg .LProduceWithU8U8Avx512Core\@ + ComputeBlockLoopU8S8 Avx512Vnni, \ColumnCount\(), \RowCount\() + jmp .LExitProduceOutputBlock\@ + +.LProduceWithU8U8Avx512Core\@: + ComputeBlockLoopU8U8 Avx512Core, \ColumnCount\(), \RowCount\() + jmp .LExitProduceOutputBlock\@ + +.LProduceWithU8S8Avx512Core\@: + ComputeBlockLoopU8S8 Avx512Core, \ColumnCount\(), \RowCount\() + +.LExitProduceOutputBlock\@: +.if \RowCount\() > 3 + lea r8,[rax*2+rax] + add r8,rdx # compute matrix C plus 3 rows +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address of matrix A. + + rsi - Supplies the address of matrix B. + + rdx - Supplies the address of matrix C. + + rbx - Supplies the address of matrix A. + + r9 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + rcx - Supplies the length in bytes of a row from matrix A. + + r10b - Supplies the zero mode flag. + + r11 - Supplies the address of the row sum buffer. + + r12 - Supplies the address of the column sum buffer. + + r14 - Supplies the stride in bytes of between packed blocks of matrix B. + +--*/ + + .macro ProcessCountM RowCount + + cmp r9,32 + ja .LProcessNextColumnLoop48xN\@ + cmp r9,16 + jbe .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop32xN\@: + ProduceOutputBlock 32, \RowCount\() + add rsi,r14 # advance matrix B by packed block stride + +.LOutput32xNBlock\@: + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput32xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [r8+rax*2]" + +.LSkipAccumulateOutput32xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm23" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm24" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm25" + add rdx,16*4 # advance matrix C by 16 columns +.if \RowCount\() > 3 + add r8,16*4 # advance matrix C plus 3 rows by 16 columns +.endif + sub r9,16 + +.LOutput16xNBlock\@: + sub r9,16 + jae .LOutput16xNBlockWithMask\@ + lea rcx,[r9+16] # correct for over-subtract above + mov ebp,1 + shl ebp,cl + dec ebp + kmovw k1,ebp # update mask for remaining columns + xor r9,r9 # no more columns remaining + +.LOutput16xNBlockWithMask\@: + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput16xNBlockWithMask\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [r8+rax*2]" + +.LSkipAccumulateOutput16xNBlockWithMask\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8]{k1},zmm17" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm18" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm19" + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,rbx # reload matrix A + cmp r9,32 + ja .LProcessNextColumnLoop48xN\@ + cmp r9,16 + ja .LProcessNextColumnLoop32xN\@ + test r9,r9 + jnz .LProcessRemainingCountN\@ + mov eax,\RowCount\() + jmp .LExitKernel + +.LProcessRemainingCountN\@: + ProduceOutputBlock 16, \RowCount\() + jmp .LOutput16xNBlock\@ + +.LProcessNextColumnLoop48xN\@: + ProduceOutputBlock 48, \RowCount\() + lea rsi,[rsi+r14*2] # advance matrix B by packed block stride + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput48xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [r8]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [r8+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [r8+rax*2]" + +.LSkipAccumulateOutput48xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm29" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm30" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm31" + add rdx,16*4 # advance matrix C by 16 columns +.if \RowCount\() > 3 + add r8,16*4 # advance matrix C plus 3 rows by 16 columns +.endif + sub r9,16 + jmp .LOutput32xNBlock\@ + + .endm + +// +// Reduce code size for the various types of kernels by sharing the outer logic +// and switching on the selector codes (using sign bit to discriminate). +// + + FUNCTION_ENTRY MlasGemmU8U8KernelAvx512Core + + mov eax,1 + jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core) + + FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Core + + xor eax,eax + jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core) + + FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Vnni + + mov eax,-1 + jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core) + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (rdi) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8X8CopyPackAAvx2. + + B (rsi) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8X8CopyPackBAvx2. + + C (rdx) - Supplies the address of matrix C. + + PackedCountK (rcx) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (r8) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (r9) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc - Supplies the first dimension of matrix C. + + RowSumBuffer - Supplies the sum of each row from matrix A. These values have + been pre-scaled by the zero point offset of matrix B if the offset is + per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + FUNCTION_ENTRY MlasGemmU8X8KernelAvx512Core + + push rbp + push rbx + push r12 + push r13 + push r14 + + mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax + mov rbx,rdi + mov rax,.LGemmU8X8KernelFrame_ldc[rsp] + shl rax,2 # convert ldc to bytes + shl rcx,2 # convert to row length + movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] + mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] + mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] + mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp] + mov ebp,-1 + kmovw k1,ebp # update mask to write all columns + neg ebp + vpbroadcastw zmm13,ebp # generate 512-bit word vector [0x0001] + lea rbp,[rcx*8] + lea r14,[rbp*2] + cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0 + cmovg r14,rbp # select matrix B packed stride + +// +// Process CountM rows of the matrices. +// + + cmp r8,5 + ja .LProcessCountM6 + je .LProcessCountM5 + cmp r8,3 + ja .LProcessCountM4 + je .LProcessCountM3 + cmp r8,1 + je .LProcessCountM1 + +.LProcessCountM2: + ProcessCountM 2 + +.LProcessCountM4: + ProcessCountM 4 + +.LProcessCountM6: + ProcessCountM 6 + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + vzeroupper + + pop r14 + pop r13 + pop r12 + pop rbx + pop rbp + ret + +.LProcessCountM1: + ProcessCountM 1 + +.LProcessCountM3: + ProcessCountM 3 + +.LProcessCountM5: + ProcessCountM 5 + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h index b4dd277825..a26754bf12 100644 --- a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h +++ b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h @@ -22,6 +22,32 @@ Abstract: /*++ +Macro Description: + + This macro emits the assembler directives to annotate a new function. + +Arguments: + + FunctionName - Supplies the name of the function. + +--*/ + + .macro FUNCTION_ENTRY FunctionName + + .p2align 4 +#if defined(__APPLE__) + .globl _\FunctionName\() +_\FunctionName\(): +#else + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): +#endif + + .endm + +/*++ + Macro Description: This macro generates an optimization for "add reg,128" which can instead diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer.cc b/onnxruntime/core/providers/cpu/math/matmul_integer.cc index 998b97565b..3402947bb3 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cpu/math/matmul_integer.cc @@ -29,8 +29,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( MatMulInteger); Status MatMulInteger::Compute(OpKernelContext* ctx) const { - concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const auto* a = ctx->Input(0); const Tensor* b = packed_b_ ? nullptr : ctx->Input(1); @@ -58,49 +56,33 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const { b_offset = *static_cast(b_zero_point->DataRaw()); } + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = static_cast(helper.M()); + gemm_params.N = static_cast(helper.N()); + gemm_params.K = static_cast(helper.K()); + gemm_params.lda = gemm_params.K; + gemm_params.ZeroPointA = a_offset; + gemm_params.ldb = gemm_params.N; + gemm_params.ZeroPointB = &b_offset; + gemm_params.ldc = gemm_params.N; + const auto* a_data = a->template Data(); auto* y_data = y->template MutableData(); -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - if (packed_b_) { - for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - MlasGemm(static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - a_data + helper.LeftOffsets()[i], - static_cast(helper.K()), - a_offset, - packed_b_.get(), - b_offset, - b_is_signed_, - y_data + helper.OutputOffsets()[i], - static_cast(helper.N()), - thread_pool); + for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { + gemm_params.A = a_data + helper.LeftOffsets()[i]; + if (packed_b_) { + gemm_params.B = packed_b_.get(); + gemm_params.BIsPacked = true; + gemm_params.BIsSigned = b_is_signed_; + } else if (b != nullptr) { + gemm_params.B = static_cast(b->DataRaw()) + + helper.RightOffsets()[i]; + gemm_params.BIsSigned = b->IsDataType(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null."); } - return Status::OK(); - } -#endif - - if (b != nullptr) { - for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - const auto* b_data = static_cast(b->DataRaw()); - const bool b_is_signed = b->IsDataType(); - MlasGemm(static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - a_data + helper.LeftOffsets()[i], - static_cast(helper.K()), - a_offset, - b_data + helper.RightOffsets()[i], - static_cast(helper.N()), - b_offset, - b_is_signed, - y_data + helper.OutputOffsets()[i], - static_cast(helper.N()), - thread_pool); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null."); + gemm_params.C = y_data + helper.OutputOffsets()[i]; + MlasGemm(&gemm_params, ctx->GetOperatorThreadPool()); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer_base.h b/onnxruntime/core/providers/cpu/math/matmul_integer_base.h index 50c99f6efb..8c14035bbf 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/math/matmul_integer_base.h @@ -11,7 +11,6 @@ class MatMulIntegerBase : public OpKernel { public: MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {} -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override { is_packed = false; @@ -43,7 +42,6 @@ class MatMulIntegerBase : public OpKernel { } return Status::OK(); } -#endif protected: bool b_is_signed_{true}; diff --git a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc index 903c3dd8a1..91feb059d8 100644 --- a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc @@ -76,20 +76,22 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const { BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc)); auto* gemm_output = static_cast(gemm_output_buffer.get()); + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = static_cast(helper.M()); + gemm_params.N = static_cast(helper.N()); + gemm_params.K = static_cast(helper.K()); + gemm_params.lda = gemm_params.K; + gemm_params.ZeroPointA = *a_offset->template Data(); + gemm_params.ldb = gemm_params.N; + gemm_params.ZeroPointB = static_cast(b_offset->DataRaw()); + gemm_params.BIsSigned = b->IsDataType(); + gemm_params.C = gemm_output; + gemm_params.ldc = gemm_params.N; + for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - MlasGemm(static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - a->template Data() + helper.LeftOffsets()[i], - static_cast(helper.K()), - *a_offset->template Data(), - static_cast(b->DataRaw()) + helper.RightOffsets()[i], - static_cast(helper.N()), - *static_cast(b_offset->DataRaw()), - b->IsDataType(), - gemm_output, - static_cast(helper.N()), - ctx->GetOperatorThreadPool()); + gemm_params.A = a->template Data() + helper.LeftOffsets()[i]; + gemm_params.B = static_cast(b->DataRaw()) + helper.RightOffsets()[i]; + MlasGemm(&gemm_params, ctx->GetOperatorThreadPool()); MlasRequantizeOutput(gemm_output, y->template MutableData() + helper.OutputOffsets()[i], diff --git a/onnxruntime/core/providers/cpu/nn/conv_integer.cc b/onnxruntime/core/providers/cpu/nn/conv_integer.cc index 9260c5385e..c9089eba80 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_integer.cc @@ -149,19 +149,19 @@ Status ConvInteger::Compute(OpKernelContext* context) const { } } - MlasGemm(static_cast(M / conv_attrs_.group), - static_cast(output_image_size), - static_cast(kernel_dim), - Wdata + group_id * W_offset, - static_cast(kernel_dim), - filter_offset, - col_buffer_data == nullptr ? Xdata : col_buffer_data, - static_cast(output_image_size), - input_offset, - false, - Ydata, - static_cast(output_image_size), - thread_pool); + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = static_cast(M / conv_attrs_.group); + gemm_params.N = static_cast(output_image_size); + gemm_params.K = static_cast(kernel_dim); + gemm_params.A = Wdata + group_id * W_offset; + gemm_params.lda = static_cast(kernel_dim); + gemm_params.ZeroPointA = filter_offset; + gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data, + gemm_params.ldb = static_cast(output_image_size); + gemm_params.ZeroPointB = &input_offset; + gemm_params.C = Ydata; + gemm_params.ldc = static_cast(output_image_size); + MlasGemm(&gemm_params, thread_pool); Xdata += X_offset; Ydata += Y_offset; diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc index cde405dfa3..bec93144ce 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc @@ -43,10 +43,8 @@ class QLinearConv : public OpKernel { ConvAttributes conv_attrs_; TensorShape W_shape_; -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 BufferUniquePtr packed_W_buffer_; size_t packed_W_size_; -#endif BufferUniquePtr reordered_W_buffer_; bool is_W_signed_; bool is_W_packed_; @@ -116,7 +114,6 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed auto alloc = Info().GetAllocator(0, OrtMemTypeDefault); -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 const size_t group_count = static_cast(conv_attrs_.group); const size_t group_output_channels = output_channels / group_count; const size_t kernel_dim = group_input_channels * kernel_size; @@ -151,7 +148,6 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed return Status::OK(); } } -#endif auto* reordered_W = static_cast(alloc->Alloc(SafeInt(sizeof(uint8_t)) * output_channels * group_input_channels * kernel_size)); reordered_W_buffer_ = BufferUniquePtr(reordered_W, BufferDeleter(alloc)); @@ -279,13 +275,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { // Handle the case of a dynamic weight filter. BufferUniquePtr reordered_W_buffer; uint8_t* reordered_W = nullptr; - bool use_reordered_W = true; -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - if (packed_W_buffer_) { - use_reordered_W = false; - } -#endif - if (use_reordered_W) { + if (!packed_W_buffer_) { if (W == nullptr) { // Weight was constant and reordered. reordered_W = static_cast(reordered_W_buffer_.get()); @@ -307,7 +297,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { int64_t group_output_channels = M / group_count; // Test for depthwise convolution. - const bool is_depthwise_conv = (use_reordered_W && group_input_channels == 1 && group_output_channels == 1); + const bool is_depthwise_conv = (reordered_W != nullptr && group_input_channels == 1 && group_output_channels == 1); if (is_depthwise_conv) { // Update the input and output channels to the number of groups in order to // reuse as much of the below standard convolution path. @@ -510,39 +500,25 @@ Status QLinearConv::Compute(OpKernelContext* context) const { worker_gemm_input = input_data + output_start * kernel_dim; } -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = static_cast(output_count); + gemm_params.N = static_cast(group_output_channels); + gemm_params.K = static_cast(kernel_dim); + gemm_params.A = worker_gemm_input; + gemm_params.lda = static_cast(kernel_dim); + gemm_params.ZeroPointA = X_zero_point_value; if (packed_W_buffer_) { - MlasGemm( - static_cast(output_count), - static_cast(group_output_channels), - static_cast(kernel_dim), - worker_gemm_input, - static_cast(kernel_dim), - X_zero_point_value, - static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, - W_zero_point_value, - is_W_signed, - worker_gemm_output + group_id * group_output_channels, - static_cast(M), - nullptr); - } else -#endif - { - MlasGemm( - static_cast(output_count), - static_cast(group_output_channels), - static_cast(kernel_dim), - worker_gemm_input, - static_cast(kernel_dim), - X_zero_point_value, - reordered_W + group_id * group_output_channels, - static_cast(M), - W_zero_point_value, - is_W_signed, - worker_gemm_output + group_id * group_output_channels, - static_cast(M), - nullptr); + gemm_params.B = static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, + gemm_params.BIsPacked = true; + } else { + gemm_params.B = reordered_W + group_id * group_output_channels, + gemm_params.ldb = static_cast(M); } + gemm_params.ZeroPointB = &W_zero_point_value; + gemm_params.BIsSigned = is_W_signed; + gemm_params.C = worker_gemm_output + group_id * group_output_channels; + gemm_params.ldc = static_cast(M); + MlasGemm(&gemm_params, nullptr); } } diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc index 0399d1193e..121a0d81f1 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc @@ -286,39 +286,23 @@ void ComputeGemm(const int M, C, ldc, scale_multiplier.data(), nullptr, beta == 1.0f ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode, scale_multiplier.size() == 1 ? MLAS_QUANTIZATION_GRANULARITY::PerMatrix : MLAS_QUANTIZATION_GRANULARITY::PerColumn); -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - if (weights.is_prepacked_) { - MlasGemm(static_cast(M), - static_cast(N), - static_cast(K), - a_data_quant, - static_cast(K), - a_zero_point, - weights.buffer_, - b_zero_point, - b_is_signed, - C_buffer, - ld_C_buffer, - thread_pool, - &output_processor); - return; - } -#endif - MlasGemm(static_cast(M), - static_cast(N), - static_cast(K), - a_data_quant, - static_cast(K), - a_zero_point, - static_cast(weights.buffer_), - static_cast(N), - b_zero_point, - b_is_signed, - C_buffer, - ld_C_buffer, - thread_pool, - &output_processor); + MLAS_GEMM_U8X8_PARAMETERS gemm_params; + gemm_params.M = static_cast(M); + gemm_params.N = static_cast(N); + gemm_params.K = static_cast(K); + gemm_params.A = a_data_quant; + gemm_params.lda = static_cast(K); + gemm_params.ZeroPointA = a_zero_point; + gemm_params.B = weights.buffer_; + gemm_params.ldb = static_cast(N); + gemm_params.ZeroPointB = &b_zero_point; + gemm_params.BIsPacked = weights.is_prepacked_; + gemm_params.BIsSigned = b_is_signed; + gemm_params.C = C_buffer; + gemm_params.ldc = ld_C_buffer; + gemm_params.OutputProcessor = &output_processor; + MlasGemm(&gemm_params, thread_pool); } namespace deepcpu { diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index 6fe67f5b18..762175b228 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -537,63 +537,7 @@ public: }; template -class MlasQgemmU8X8U8X8TestBase; - -template<> -class MlasQgemmU8X8U8X8TestBase : public MlasTestBase -{ -protected: - void - TestGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - int32_t* C, - size_t ldc - ) - { - MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc, threadpool); - } - - void - TestGemm( - size_t M, - size_t N, - size_t K, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - float* C, - size_t ldc, - float CScale, - const float* Bias - ) - { - MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias); - MlasGemm(M, N, K, - A, lda, offa, - B, ldb, offb, BIsSigned, - reinterpret_cast(C), ldc, - threadpool, - &scale_bias_processor); - } -}; - -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - -template<> -class MlasQgemmU8X8U8X8TestBase : public MlasTestBase +class MlasQgemmU8X8U8X8TestBase : public MlasTestBase { private: void* @@ -628,8 +572,69 @@ protected: size_t ldc ) { - const void* PackedB = PackB(N, K, B, ldb, BIsSigned); - MlasGemm(M, N, K, A, lda, offa, PackedB, offb, BIsSigned, C, ldc, threadpool); + MLAS_GEMM_U8X8_PARAMETERS GemmParameters; + + GemmParameters.M = M; + GemmParameters.N = N; + GemmParameters.K = K; + GemmParameters.A = A; + GemmParameters.lda = lda; + GemmParameters.ZeroPointA = offa; + GemmParameters.ZeroPointB = &offb; + GemmParameters.BIsSigned = BIsSigned; + GemmParameters.C = C; + GemmParameters.ldc = ldc; + + if (Packed) { + GemmParameters.B = PackB(N, K, B, ldb, BIsSigned); + GemmParameters.BIsPacked = true; + } else { + GemmParameters.B = B; + GemmParameters.ldb = ldb; + } + + MlasGemm(&GemmParameters, threadpool); + } + + void + TestGemm( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const uint8_t* B, + size_t ldb, + const uint8_t* offb, + bool BIsSigned, + int32_t* C, + size_t ldc + ) + { + MLAS_GEMM_U8X8_PARAMETERS GemmParameters; + + GemmParameters.M = M; + GemmParameters.N = N; + GemmParameters.K = K; + GemmParameters.A = A; + GemmParameters.lda = lda; + GemmParameters.ZeroPointA = offa; + GemmParameters.ZeroPointB = offb; + GemmParameters.BIsSigned = BIsSigned; + GemmParameters.PerColumnZeroPoints = true; + GemmParameters.C = C; + GemmParameters.ldc = ldc; + + if (Packed) { + GemmParameters.B = PackB(N, K, B, ldb, BIsSigned); + GemmParameters.BIsPacked = true; + } else { + GemmParameters.B = B; + GemmParameters.ldb = ldb; + } + + MlasGemm(&GemmParameters, threadpool); } void @@ -650,22 +655,37 @@ protected: const float* Bias ) { - const void* PackedB = PackB(N, K, B, ldb, BIsSigned); - MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias); - MlasGemm(M, N, K, - A, lda, offa, - PackedB, offb, BIsSigned, - reinterpret_cast(C), ldc, - threadpool, - &scale_bias_processor); + MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR ScaleBiasProcessor(C, ldc, &CScale, Bias); + + MLAS_GEMM_U8X8_PARAMETERS GemmParameters; + + GemmParameters.M = M; + GemmParameters.N = N; + GemmParameters.K = K; + GemmParameters.A = A; + GemmParameters.lda = lda; + GemmParameters.ZeroPointA = offa; + GemmParameters.ZeroPointB = &offb; + GemmParameters.BIsSigned = BIsSigned; + GemmParameters.C = reinterpret_cast(C); + GemmParameters.ldc = ldc; + GemmParameters.OutputProcessor = &ScaleBiasProcessor; + + if (Packed) { + GemmParameters.B = PackB(N, K, B, ldb, BIsSigned); + GemmParameters.BIsPacked = true; + } else { + GemmParameters.B = B; + GemmParameters.ldb = ldb; + } + + MlasGemm(&GemmParameters, threadpool); } private: MatrixGuardBuffer BufferBPacked; }; -#endif - template class MlasQgemmU8X8Test; @@ -690,6 +710,23 @@ private: Test(M, N, K, A, K, offa, B, N, offb, C, CReference, N); } + void + Test( + size_t M, + size_t N, + size_t K, + uint8_t offa + ) + { + const uint8_t* A = BufferA.GetBuffer(K * M); + const uint8_t* B = BufferB.GetBuffer(N * K); + const uint8_t* ZeroPointB = BufferZeroPointB.GetBuffer(N); + int32_t* C = BufferC.GetBuffer(N * M); + int32_t* CReference = BufferCReference.GetBuffer(N * M); + + Test(M, N, K, A, K, offa, B, N, ZeroPointB, C, CReference, N); + } + void Test( size_t M, @@ -720,6 +757,36 @@ private: } } + void + Test( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const uint8_t* B, + size_t ldb, + const uint8_t *offb, + int32_t* C, + int32_t* CReference, + size_t ldc + ) + { + std::fill_n(C, M * N, -1); + std::fill_n(CReference, M * N, -1); + + this->TestGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc); + ReferenceQgemm(M, N, K, A, lda, offa, (const xint8_t*)B, ldb, (const xint8_t*)offb, CReference, ldc); + + for (size_t f = 0; f < M * N; f++) { + if (C[f] != CReference[f]) { + printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d!\n", M, N, K, int(offa)); + break; + } + } + } + void ReferenceQgemm( size_t M, @@ -755,8 +822,44 @@ private: } } + void + ReferenceQgemm( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const xint8_t* B, + size_t ldb, + const xint8_t* offb, + int32_t* C, + size_t ldc + ) + { + for (size_t m = 0; m < M; m++) { + + for (size_t n = 0; n < N; n++) { + + const uint8_t* a = A + (m * lda); + const xint8_t* b = B + n; + int32_t* c = C + (m * ldc) + n; + int32_t sum = 0; + + for (size_t k = 0; k < K; k++) { + sum += ((int32_t(*b) - offb[n]) * (int32_t(*a) - offa)); + b += ldb; + a += 1; + } + + *c = sum; + } + } + } + MatrixGuardBuffer BufferA; MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferZeroPointB; MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; const bool BIsSigned = std::is_signed::value; @@ -769,12 +872,15 @@ public: { for (size_t b = 1; b < 16; b++) { Test(b, b, b, 14, 211); + Test(b, b, b, 21); } for (size_t b = 1; b < 16; b++) { Test(b, b, b, 14, 211); + Test(b, b, b, 17); } for (size_t b = 16; b <= 256; b <<= 1) { Test(b, b, b, 34, 1); + Test(b, b, b, 1); } for (size_t b = 256; b < 320; b += 32) { Test(b, b, b, 85, 173); @@ -786,6 +892,7 @@ public: } Test(43, 500, 401, 183, 223); Test(1023, 1023, 1023, 5, 8); + Test(1023, 1023, 1023, 7); } void @@ -3130,7 +3237,6 @@ RunThreadedTests( printf("QGEMM U8U8=float tests.\n"); onnxruntime::make_unique>()->ExecuteShort(); -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 if (MlasGemmPackBSize(128, 128, true) > 0) { printf("QGEMM U8S8=int32_t packed tests.\n"); onnxruntime::make_unique>()->ExecuteShort(); @@ -3143,7 +3249,6 @@ RunThreadedTests( printf("QGEMM U8U8=float packed tests.\n"); onnxruntime::make_unique>()->ExecuteShort(); } -#endif printf("Conv2D tests.\n"); onnxruntime::make_unique()->ExecuteShort();