From 28a62f772859d1d2ce43c99072dddec76a7b6adf Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Tue, 24 Sep 2019 18:15:11 -0700 Subject: [PATCH] MLAS: add U8S8 MatMul operation (#1895) Implement the second round of changes for quantization inside MLAS. This adds a MatMul operation for U8xS8=S32 for x86/x64 processors. --- cmake/onnxruntime_mlas.cmake | 6 + onnxruntime/core/mlas/inc/mlas.h | 19 +- .../mlas/lib/amd64/AssembleAvx512Vnni.inc | 36 +- .../mlas/lib/amd64/QgemmU8S8KernelAvx2.asm | 1053 +++++++++++++++++ .../lib/amd64/QgemmU8S8KernelAvx512BW.asm | 130 ++ .../lib/amd64/QgemmU8S8KernelAvx512Common.inc | 91 ++ .../lib/amd64/QgemmU8S8KernelAvx512Vnni.asm | 102 ++ .../mlas/lib/amd64/QgemmU8U8KernelAvx2.asm | 454 ++----- .../lib/amd64/QgemmU8U8KernelAvx512BW.asm | 91 +- .../lib/amd64/QgemmU8U8KernelAvx512Common.inc | 353 +----- .../lib/amd64/QgemmU8U8KernelAvx512Vnni.asm | 81 +- .../lib/amd64/QgemmU8X8KernelAvx2Common.inc | 302 +++++ .../lib/amd64/QgemmU8X8KernelAvx512Common.inc | 438 +++++++ onnxruntime/core/mlas/lib/mlasi.h | 59 +- onnxruntime/core/mlas/lib/platform.cpp | 9 + onnxruntime/core/mlas/lib/qgemm.cpp | 725 +++++++++++- .../core/mlas/lib/x86_64/AssembleAvx512Vnni.h | 34 +- .../mlas/lib/x86_64/QgemmU8S8KernelAvx2.S | 955 +++++++++++++++ .../mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S | 136 +++ .../lib/x86_64/QgemmU8S8KernelAvx512Common.h | 88 ++ .../lib/x86_64/QgemmU8S8KernelAvx512Vnni.S | 106 ++ .../mlas/lib/x86_64/QgemmU8U8KernelAvx2.S | 392 ++---- .../mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S | 95 +- .../lib/x86_64/QgemmU8U8KernelAvx512Common.h | 327 +---- .../lib/x86_64/QgemmU8U8KernelAvx512Vnni.S | 85 +- .../lib/x86_64/QgemmU8X8KernelAvx2Common.h | 273 +++++ .../lib/x86_64/QgemmU8X8KernelAvx512Common.h | 403 +++++++ .../core/providers/cpu/math/matmul_integer.cc | 29 +- onnxruntime/core/util/qmath.cc | 49 +- onnxruntime/core/util/qmath.h | 34 +- onnxruntime/test/mlas/unittest.cpp | 32 +- 31 files changed, 
5415 insertions(+), 1572 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc create mode 100644 onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 3bc52bb62b..7704a8cadd 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -44,6 +44,9 @@ if(MSVC) enable_language(ASM_MASM) set(mlas_platform_srcs + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm @@ -158,6 +161,7 @@ else() set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") set(mlas_platform_srcs_avx2 + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S 
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S @@ -175,6 +179,8 @@ else() set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") set(mlas_platform_srcs_avx512bw + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S ) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 884d97042b..4002c59d60 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -135,7 +135,24 @@ MlasSgemm( void MLASCALL -MlasQgemm( +MlasGemm( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const int8_t* B, + size_t ldb, + int8_t offb, + int32_t* C, + size_t ldc, + MLAS_THREADPOOL* ThreadPool + ); + +void +MLASCALL +MlasGemm( size_t M, size_t N, size_t K, diff --git a/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc b/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc index 02f7d92256..ed885dcb7b 100644 --- a/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc +++ b/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc @@ -139,7 +139,7 @@ VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg ; ; This macro builds a VNNI instruction of the form: ; -; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale] +; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale+ByteOffset] ; ; Arguments: ; @@ -151,15 +151,20 @@ VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg ; ; BaseReg - Specifies the base register of the broadcast operand. ; +; ByteOffset - Specifies the DWORD aligned byte offset for the broadcast +; operand. +; ; IndexReg - Specifies the optional index register of the broadcast operand. ; ; Scale - Specifies the scaling factor of the optional index register. 
; -VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale +VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale LOCAL Payload0, Payload1, Payload2, ModRMByte, SibByte +.errnz (ByteOffset AND 3) + Payload0 = 002h ; "0F 38" prefix Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) IFNB @@ -183,6 +188,9 @@ IFNB ELSE ModRMByte = ModRMByte + (GprIndex_&BaseReg& AND 7) ENDIF +IF ByteOffset NE 0 + ModRMByte = ModRMByte + 040h ; indicate disp8 byte offset +ENDIF IFNB SibByte = 0 @@ -199,34 +207,36 @@ ENDIF SibByte = SibByte + (GprIndex_&BaseReg& AND 7) ENDIF -IFNB - db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte, SibByte -ELSE db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte +IFNB + db SibByte +ENDIF +IF ByteOffset NE 0 + db ByteOffset SHR 2 ENDIF ENDM -VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale +VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale ENDM -VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale +VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale ENDM -VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale +VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale ENDM -VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale +VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 
053h, DestReg, Src1Reg, BaseReg, IndexReg, Scale + VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm new file mode 100644 index 0000000000..b58cb52669 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm @@ -0,0 +1,1053 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8S8KernelAvx2.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX2 instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE QgemmU8X8KernelAvx2Common.inc + .list + +; +; Stack frame layout for the U8S8 CopyPackA routine. +; + +GemmU8S8CopyPackAFrame STRUCT + + PaddedMatrixAData OWORD 4 DUP (?) + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + SavedXmm9 OWORD ? + SavedXmm10 OWORD ? + Padding QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountK QWORD ? + RowSumVector QWORD ? + offb QWORD ? + +GemmU8S8CopyPackAFrame ENDS + +; +; Stack frame layout for the U8S8 CopyPackB routine. +; + +GemmU8S8CopyPackBFrame STRUCT + + PaddedMatrixBData OWORD 4 DUP (?) + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + Padding QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountK QWORD ? + ColumnSumVector QWORD ? + offa QWORD ? 
+ +GemmU8S8CopyPackBFrame ENDS + +;++ +; +; Routine Description: +; +; This routine copies elements from the source matrix to the destination +; packed buffer. +; +; Arguments: +; +; D (rcx) - Supplies the address of the destination packed buffer. +; +; A (rdx) - Supplies the address of the source matrix. +; +; lda (r8) - Supplies the number of elements per row of the source matrix. +; +; CountM (r9) - Supplies the number of rows of the source matrix to copy. +; +; CountK - Supplies the number of columns of the source matrix to copy. +; +; RowSumVector - Supplies the address of the buffer to receive the sums of +; the elements from each of the rows, each sum multiplied by offb. +; +; offb - Supplies the zero point offset for the other source matrix of the +; matrix multiplication. +; +; Return Value: +; +; None. +; +;-- + + NESTED_ENTRY MlasGemmU8S8CopyPackAAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + alloc_stack (GemmU8S8CopyPackAFrame.SavedR13) + save_xmm128 xmm6,GemmU8S8CopyPackAFrame.SavedXmm6 + save_xmm128 xmm7,GemmU8S8CopyPackAFrame.SavedXmm7 + save_xmm128 xmm8,GemmU8S8CopyPackAFrame.SavedXmm8 + save_xmm128 xmm9,GemmU8S8CopyPackAFrame.SavedXmm9 + save_xmm128 xmm10,GemmU8S8CopyPackAFrame.SavedXmm10 + + END_PROLOGUE + + mov rdi,rcx + mov rsi,rdx + mov r10,GemmU8S8CopyPackAFrame.CountK[rsp] + lea r11,[r10+3] + and r11,NOT 3 ; align CountK up to quad count + mov r12,GemmU8S8CopyPackAFrame.RowSumVector[rsp] + vpbroadcastw xmm8,WORD PTR GemmU8S8CopyPackAFrame.offb[rsp] + vpcmpeqw ymm9,ymm9,ymm9 ; generate word vector [0xFFFF] + vpsrlw ymm9,ymm9,15 ; generate word vector [0x0001] + vpsllw ymm0,ymm9,8 ; generate word vector [0x0100] + vpor ymm9,ymm9,ymm0 ; generate word vector [0x0101] + +; +; Compute the conditional load/store mask for an unaligned CountK.
+; + + mov eax,r10d + and eax,15 ; isolate unaligned count + add eax,3 + shr eax,2 ; align unaligned count to quad count + mov DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp],eax + vpbroadcastd xmm10,DWORD PTR GemmU8S8CopyPackAFrame.CountK[rsp] + vpcmpgtd xmm10,xmm10,XMMWORD PTR [MlasMaskMoveAvx] + +; +; Zero initialize the padded stack buffers. +; + + vpxor xmm0,xmm0,xmm0 + vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0 + vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0 + +; +; Process 4 rows of matrix A in a loop. +; + + sub r9,4 + jb ProcessRemainingRows + +ProcessNextRowM4: + vpxor xmm0,xmm0,xmm0 ; clear row accumulators + vpxor xmm1,xmm1,xmm1 + vpxor xmm2,xmm2,xmm2 + vpxor xmm3,xmm3,xmm3 + mov rdx,rsi + mov rcx,rdi + lea rsi,[rsi+r8*4] ; advance next matrix A by 4 rows + lea rdi,[rdi+r11*4] ; advance next matrix D by 4 rows + mov rbx,r10 ; reload columns remaining + sub rbx,32 + jb ProcessRemainingColumnsM4 + +ProcessNextColumnLoopM4: + lea rax,[rdx+r8*2] ; compute matrix A plus 2 rows + vmovdqu ymm4,YMMWORD PTR [rdx] + vmovdqu ymm5,YMMWORD PTR [rdx+r8] + vmovdqu ymm6,YMMWORD PTR [rax] + vmovdqu ymm7,YMMWORD PTR [rax+r8] + lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows + vmovdqu YMMWORD PTR [rcx],ymm4 + vmovdqu YMMWORD PTR [rcx+r11],ymm5 + vmovdqu YMMWORD PTR [rax],ymm6 + vmovdqu YMMWORD PTR [rax+r11],ymm7 + vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + vpmaddubsw ymm5,ymm5,ymm9 + vpaddw ymm1,ymm1,ymm5 + vpmaddubsw ymm6,ymm6,ymm9 + vpaddw ymm2,ymm2,ymm6 + vpmaddubsw ymm7,ymm7,ymm9 + vpaddw ymm3,ymm3,ymm7 + add rdx,32 ; advance matrix A by 32 bytes + add rcx,32 ; advance matrix D by 32 bytes + sub rbx,32 ; subtract columns remaining + jae ProcessNextColumnLoopM4 + +ProcessRemainingColumnsM4: + add rbx,32 ; correct for over-subtract above + jz ReduceRowSumVectorM4 + test bl,16 ; (CountK & 16) != 0? 
+ jz CopyRemainingCountKLessThan16M4 + lea rax,[rdx+r8*2] ; compute matrix A plus 2 rows + vmovdqu xmm4,XMMWORD PTR [rdx] + vmovdqu xmm5,XMMWORD PTR [rdx+r8] + vmovdqu xmm6,XMMWORD PTR [rax] + vmovdqu xmm7,XMMWORD PTR [rax+r8] + lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows + vmovdqu XMMWORD PTR [rcx],xmm4 + vmovdqu XMMWORD PTR [rcx+r11],xmm5 + vmovdqu XMMWORD PTR [rax],xmm6 + vmovdqu XMMWORD PTR [rax+r11],xmm7 + vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + vpmaddubsw xmm5,xmm5,xmm9 + vpaddw ymm1,ymm1,ymm5 + vpmaddubsw xmm6,xmm6,xmm9 + vpaddw ymm2,ymm2,ymm6 + vpmaddubsw xmm7,xmm7,xmm9 + vpaddw ymm3,ymm3,ymm7 + add rdx,16 ; advance matrix A by 16 bytes + add rcx,16 ; advance matrix D by 16 bytes + test bl,15 ; test for unaligned columns + jz ReduceRowSumVectorM4 + +; +; Copy the unaligned CountK columns to a zero padded stack buffer. +; + +CopyRemainingCountKLessThan16M4: +.errnz GemmU8S8CopyPackAFrame.PaddedMatrixAData + mov rbp,rsp ; GemmU8S8CopyPackAFrame.PaddedMatrixAData + test bl,8 ; (CountK & 8) != 0? + jz CopyRemainingCountKLessThan8M4 + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r8] + mov QWORD PTR [rbp+16],rax + mov rax,QWORD PTR [r13] + mov QWORD PTR [rbp+32],rax + mov rax,QWORD PTR [r13+r8] + mov QWORD PTR [rbp+48],rax + add rdx,8 + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountKLessThan8M4: + test bl,4 ; (CountK & 4) != 0? + jz CopyRemainingCountKLessThan4M4 + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r8] + mov DWORD PTR [rbp+16],eax + mov eax,DWORD PTR [r13] + mov DWORD PTR [rbp+32],eax + mov eax,DWORD PTR [r13+r8] + mov DWORD PTR [rbp+48],eax + add rdx,4 + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountKLessThan4M4: + test bl,2 ; (CountK & 2) != 0? 
+ jz CopyRemainingCountKLessThan2M4 + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r8] + mov WORD PTR [rbp+16],ax + movzx eax,WORD PTR [r13] + mov WORD PTR [rbp+32],ax + movzx eax,WORD PTR [r13+r8] + mov WORD PTR [rbp+48],ax + add rdx,2 + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountKLessThan2M4: + test bl,1 ; (CountK & 1) != 0? + jz ProcessPaddedMatrixADataM4 + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r8] + mov BYTE PTR [rbp+16],al + movzx eax,BYTE PTR [r13] + mov BYTE PTR [rbp+32],al + movzx eax,BYTE PTR [r13+r8] + mov BYTE PTR [rbp+48],al + +; +; Process the remaining CountK columns using the zero padded stack buffer. +; + +ProcessPaddedMatrixADataM4: + vmovdqu xmm4,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp] + vmovdqu xmm5,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+16] + vmovdqu xmm6,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32] + vmovdqu xmm7,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+48] + lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows + vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 + vpmaskmovd XMMWORD PTR [rcx+r11],xmm10,xmm5 + vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 + vpmaskmovd XMMWORD PTR [rax+r11],xmm10,xmm7 + vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + vpmaddubsw xmm5,xmm5,xmm9 + vpaddw ymm1,ymm1,ymm5 + vpmaddubsw xmm6,xmm6,xmm9 + vpaddw ymm2,ymm2,ymm6 + vpmaddubsw xmm7,xmm7,xmm9 + vpaddw ymm3,ymm3,ymm7 + +; +; Reduce the sums for the four rows of output. 
+; + +ReduceRowSumVectorM4: + vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0 + vphaddw ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2 + vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0 + vextracti128 xmm1,ymm0,1 ; extract high pairs + vpaddw xmm0,xmm0,xmm1 ; reduce low/high pairs + vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce 32-bit sum + vmovdqu XMMWORD PTR [r12],xmm0 + add r12,4*4 ; advance row sum vector by 4 DWORDs + sub r9,4 ; subtract rows remaining + jae ProcessNextRowM4 + +ProcessRemainingRows: + add r9,4 ; correct for over-subtract above + jz ExitRoutine + +; +; Process a single row of matrix A in a loop. +; + +ProcessNextRowM1: + vpxor xmm0,xmm0,xmm0 ; clear row accumulator + mov rdx,rsi + mov rcx,rdi + add rsi,r8 + add rdi,r11 + mov rbx,r10 ; reload columns remaining + sub rbx,32 + jb ProcessRemainingColumnsM1 + +ProcessNextColumnLoopM1: + vmovdqu ymm4,YMMWORD PTR [rdx] + vmovdqu YMMWORD PTR [rcx],ymm4 + vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + add rdx,32 ; advance matrix A by 32 bytes + add rcx,32 ; advance matrix D by 32 bytes + sub rbx,32 ; subtract columns remaining + jae ProcessNextColumnLoopM1 + +ProcessRemainingColumnsM1: + add rbx,32 ; correct for over-subtract above + jz ReduceRowSumVectorM1 + test bl,16 ; (CountK & 16) != 0? + jz CopyRemainingCountKLessThan16M1 + vmovdqu xmm4,XMMWORD PTR [rdx] + vmovdqu XMMWORD PTR [rcx],xmm4 + vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + add rdx,16 ; advance matrix A by 16 bytes + add rcx,16 ; advance matrix D by 16 bytes + test bl,15 ; test for unaligned columns + jz ReduceRowSumVectorM1 + +; +; Copy the unaligned CountK columns to a zero padded stack buffer. 
+; + +CopyRemainingCountKLessThan16M1: +.errnz GemmU8S8CopyPackAFrame.PaddedMatrixAData + mov rbp,rsp ; GemmU8S8CopyPackAFrame.PaddedMatrixAData + test bl,8 ; (CountK & 8) != 0? + jz CopyRemainingCountKLessThan8M1 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountKLessThan8M1: + test bl,4 ; (CountK & 4) != 0? + jz CopyRemainingCountKLessThan4M1 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountKLessThan4M1: + test bl,2 ; (CountK & 2) != 0? + jz CopyRemainingCountKLessThan2M1 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountKLessThan2M1: + test bl,1 ; (CountK & 1) != 0? + jz ProcessPaddedMatrixADataM1 + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +; +; Process the remaining CountK columns using the zero padded stack buffer. +; + +ProcessPaddedMatrixADataM1: + vmovdqu xmm4,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp] + vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 + vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns + +; +; Reduce the sum for the single row of output. +; + +ReduceRowSumVectorM1: + vextracti128 xmm1,ymm0,1 ; extract high pairs + vpaddw xmm0,xmm0,xmm1 ; reduction + vphaddw xmm0,xmm0,xmm0 + vphaddw xmm0,xmm0,xmm0 + vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce + vmovd DWORD PTR [r12],xmm0 + add r12,4 ; advance row sum vector by 1 DWORD + dec r9 ; decrement rows remaining + jnz ProcessNextRowM1 + +; +; Restore non-volatile registers and return. 
+; + +ExitRoutine: + vzeroupper + movaps xmm6,GemmU8S8CopyPackAFrame.SavedXmm6[rsp] + movaps xmm7,GemmU8S8CopyPackAFrame.SavedXmm7[rsp] + movaps xmm8,GemmU8S8CopyPackAFrame.SavedXmm8[rsp] + movaps xmm9,GemmU8S8CopyPackAFrame.SavedXmm9[rsp] + movaps xmm10,GemmU8S8CopyPackAFrame.SavedXmm10[rsp] + add rsp,(GemmU8S8CopyPackAFrame.SavedR13) + + BEGIN_EPILOGUE + + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + + NESTED_END MlasGemmU8S8CopyPackAAvx2, _TEXT + +;++ +; +; Routine Description: +; +; This routine copies elements from the source matrix to the destination +; packed buffer. +; +; Arguments: +; +; D (rcx) - Supplies the address of the destination packed buffer. +; +; B (rdx) - Supplies the address of the source matrix. +; +; ldb (r8) - Supplies the number of elements per row of the source matrix. +; +; CountN (r9) - Supplies the number of columns of the source matrix to copy. +; +; CountK - Supplies the number of rows of the source matrix to copy. +; +; ColumnSumVector - Supplies the address of the buffer to receive the sums of +; the elements from each of the columns. Each sum has also been multiplied +; by the zero point offset. +; +; offa - Supplies the zero point offset for the other source matrix of the +; matrix multiplication. +; +; Return Value: +; +; None. 
+; +;-- + + NESTED_ENTRY MlasGemmU8S8CopyPackBAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + alloc_stack (GemmU8S8CopyPackBFrame.SavedRdi) + save_xmm128 xmm6,GemmU8S8CopyPackBFrame.SavedXmm6 + save_xmm128 xmm7,GemmU8S8CopyPackBFrame.SavedXmm7 + save_xmm128 xmm8,GemmU8S8CopyPackBFrame.SavedXmm8 + + END_PROLOGUE + + mov rsi,rdx + mov r10,GemmU8S8CopyPackBFrame.CountK[rsp] + mov r11,GemmU8S8CopyPackBFrame.ColumnSumVector[rsp] + vpbroadcastw ymm7,WORD PTR GemmU8S8CopyPackBFrame.offa[rsp] + vpcmpeqw ymm8,ymm8,ymm8 ; generate word vector [0xFFFF] + vpsrlw ymm8,ymm8,15 ; generate word vector [0x0001] + vpsllw ymm0,ymm8,8 ; generate word vector [0x0100] + vpor ymm8,ymm8,ymm0 ; generate word vector [0x0101] + +; +; Process 16 columns of matrix B in a loop. +; + + sub r9,16 + jb ProcessRemainingColumns + +ProcessNextColumnN16: + vpxor xmm0,xmm0,xmm0 ; clear column accumulators + vpxor xmm1,xmm1,xmm1 + mov rdx,rsi + add rsi,16 ; advance next matrix B by 16 columns + mov rbx,r10 ; reload rows remaining + sub rbx,4 + jb ProcessRemainingRowsN16 + +ProcessNextRowLoopN16: + lea rax,[rdx+r8*2] ; compute matrix B plus 2 rows + vmovdqu xmm2,XMMWORD PTR [rdx] ; load 4 rows + vmovdqu xmm3,XMMWORD PTR [rdx+r8] + vmovdqu xmm4,XMMWORD PTR [rax] + vmovdqu xmm5,XMMWORD PTR [rax+r8] + lea rdx,[rdx+r8*4] ; advance matrix B by 4 rows + +InterleaveRowDataN16: + vpunpcklbw xmm6,xmm2,xmm3 ; interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vpunpcklbw xmm2,xmm4,xmm5 + vpunpckhbw xmm5,xmm4,xmm5 + vpunpcklwd xmm4,xmm6,xmm2 + vpunpckhwd xmm6,xmm6,xmm2 + vpunpcklwd xmm2,xmm3,xmm5 + vpunpckhwd xmm3,xmm3,xmm5 + vinsertf128 ymm4,ymm4,xmm6,1 + vinsertf128 ymm2,ymm2,xmm3,1 + vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows + vmovdqu YMMWORD PTR [rcx+32],ymm2 + vpmaddubsw ymm4,ymm8,ymm4 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + vpmaddubsw ymm2,ymm8,ymm2 + vpaddw ymm1,ymm1,ymm2 + add rcx,64 ; advance matrix D by 64 
bytes + sub rbx,4 ; subtract rows remaining + jae ProcessNextRowLoopN16 + +; +; Process the less than 4 remaining rows where the row has 16 columns. +; + +ProcessRemainingRowsN16: + add rbx,4 ; correct for over-subtract above + jz ReduceColumnSumVectorN16 + vmovdqu xmm2,XMMWORD PTR [rdx] + vpxor xmm3,xmm3,xmm3 + vpxor xmm4,xmm4,xmm4 + vpxor xmm5,xmm5,xmm5 + xor ebx,ebx ; no more rows remaining + test r10b,2 ; (CountK & 2) != 0? + jz InterleaveRowDataN16 + vmovdqu xmm3,XMMWORD PTR [rdx+r8] + test r10b,1 ; (CountK & 1) != 0? + jz InterleaveRowDataN16 + vmovdqu xmm4,XMMWORD PTR [rdx+r8*2] + jmp InterleaveRowDataN16 + +ReduceColumnSumVectorN16: + vpmaddwd ymm0,ymm0,ymm7 ; multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm7 ; multiply by offset and reduce + vmovdqu YMMWORD PTR [r11],ymm0 + vmovdqu YMMWORD PTR [r11+32],ymm1 + add r11,16*4 ; advance column sum vector by 16 DWORDs + sub r9,16 ; subtract columns remaining + jae ProcessNextColumnN16 + +ProcessRemainingColumns: + add r9,16 ; correct for over-subtract above + jnz ProcessColumnNUnaligned + +; +; Restore non-volatile registers and return. +; + +ExitRoutine: + vzeroupper + movaps xmm6,GemmU8S8CopyPackBFrame.SavedXmm6[rsp] + movaps xmm7,GemmU8S8CopyPackBFrame.SavedXmm7[rsp] + movaps xmm8,GemmU8S8CopyPackBFrame.SavedXmm8[rsp] + add rsp,(GemmU8S8CopyPackBFrame.SavedRdi) + + BEGIN_EPILOGUE + + pop rdi + pop rsi + pop rbx + pop rbp + ret + +; +; Process the remaining columns of matrix B. +; + +ProcessColumnNUnaligned: + vpxor xmm0,xmm0,xmm0 ; clear column accumulators + vpxor xmm1,xmm1,xmm1 + vmovdqu YMMWORD PTR GemmU8S8CopyPackBFrame.PaddedMatrixBData[rsp],ymm0 + vmovdqu YMMWORD PTR GemmU8S8CopyPackBFrame.PaddedMatrixBData[rsp+32],ymm0 + sub r10,4 + jb ProcessRemainingRowsNUnaligned + +ProcessNextRowLoopNUnaligned: + mov rdx,rsi +.errnz GemmU8S8CopyPackBFrame.PaddedMatrixBData + mov rbp,rsp ; GemmU8S8CopyPackBFrame.PaddedMatrixBData + test r9b,8 ; (CountN & 8) != 0? 
+ jz CopyRemainingCountNLessThan8K4 + lea rdi,[rdx+r8*2] ; compute matrix B plus 2 rows + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r8] + mov QWORD PTR [rbp+16],rax + mov rax,QWORD PTR [rdi] + mov QWORD PTR [rbp+32],rax + mov rax,QWORD PTR [rdi+r8] + mov QWORD PTR [rbp+48],rax + add rdx,8 ; advance matrix B + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountNLessThan8K4: + test r9b,4 ; (CountN & 4) != 0? + jz CopyRemainingCountNLessThan4K4 + lea rdi,[rdx+r8*2] ; compute matrix B plus 2 rows + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r8] + mov DWORD PTR [rbp+16],eax + mov eax,DWORD PTR [rdi] + mov DWORD PTR [rbp+32],eax + mov eax,DWORD PTR [rdi+r8] + mov DWORD PTR [rbp+48],eax + add rdx,4 ; advance matrix B + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountNLessThan4K4: + test r9b,2 ; (CountN & 2) != 0? + jz CopyRemainingCountNLessThan2K4 + lea rdi,[rdx+r8*2] ; compute matrix B plus 2 rows + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r8] + mov WORD PTR [rbp+16],ax + movzx eax,WORD PTR [rdi] + mov WORD PTR [rbp+32],ax + movzx eax,WORD PTR [rdi+r8] + mov WORD PTR [rbp+48],ax + add rdx,2 ; advance matrix B + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountNLessThan2K4: + test r9b,1 ; (CountN & 1) != 0? 
+ jz ProcessPaddedMatrixBData + lea rdi,[rdx+r8*2] ; compute matrix B plus 2 rows + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r8] + mov BYTE PTR [rbp+16],al + movzx eax,BYTE PTR [rdi] + mov BYTE PTR [rbp+32],al + movzx eax,BYTE PTR [rdi+r8] + mov BYTE PTR [rbp+48],al + +ProcessPaddedMatrixBData: + vmovdqu xmm2,XMMWORD PTR GemmU8S8CopyPackBFrame.PaddedMatrixBData[rsp] + vmovdqu xmm3,XMMWORD PTR GemmU8S8CopyPackBFrame.PaddedMatrixBData[rsp+16] + vmovdqu xmm4,XMMWORD PTR GemmU8S8CopyPackBFrame.PaddedMatrixBData[rsp+32] + vmovdqu xmm5,XMMWORD PTR GemmU8S8CopyPackBFrame.PaddedMatrixBData[rsp+48] + vpunpcklbw xmm6,xmm2,xmm3 ; interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vpunpcklbw xmm2,xmm4,xmm5 + vpunpckhbw xmm5,xmm4,xmm5 + vpunpcklwd xmm4,xmm6,xmm2 + vpunpckhwd xmm6,xmm6,xmm2 + vpunpcklwd xmm2,xmm3,xmm5 + vpunpckhwd xmm3,xmm3,xmm5 + vinsertf128 ymm4,ymm4,xmm6,1 + vinsertf128 ymm2,ymm2,xmm3,1 + vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows + vmovdqu YMMWORD PTR [rcx+32],ymm2 + vpmaddubsw ymm4,ymm8,ymm4 ; horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators + vpmaddubsw ymm2,ymm8,ymm2 + vpaddw ymm1,ymm1,ymm2 + lea rsi,[rsi+r8*4] ; advance next matrix B by 4 rows + add rcx,64 ; advance matrix D by 64 bytes + sub r10,4 ; subtract rows remaining + jae ProcessNextRowLoopNUnaligned + +ProcessRemainingRowsNUnaligned: + add r10,4 + jz ReduceColumnSumVectorNUnaligned + +; +; Process the less than 4 remaining rows where the row has less than 16 columns. +; + +.errnz GemmU8S8CopyPackBFrame.PaddedMatrixBData + mov rbp,rsp ; GemmU8S8CopyPackBFrame.PaddedMatrixBData + vpxor xmm6,xmm6,xmm6 + vmovdqu YMMWORD PTR [rbp],ymm6 + vmovdqu YMMWORD PTR [rbp+32],ymm6 + +CopyUnalignedRowLoop: + lea rdi,[rbp+16] ; advance next padded buffer by 16 bytes + mov rdx,rsi + test r9b,8 ; (CountN & 8) != 0? 
+ jz CopyRemainingCountNLessThan8KSmall + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 ; advance matrix B + add rbp,8 ; advance padded buffer destination + +CopyRemainingCountNLessThan8KSmall: + test r9b,4 ; (CountN & 4) != 0? + jz CopyRemainingCountNLessThan4KSmall + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 ; advance matrix B + add rbp,4 ; advance padded buffer destination + +CopyRemainingCountNLessThan4KSmall: + test r9b,2 ; (CountN & 2) != 0? + jz CopyRemainingCountNLessThan2KSmall + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 ; advance matrix B + add rbp,2 ; advance padded buffer destination + +CopyRemainingCountNLessThan2KSmall: + test r9b,1 ; (CountN & 1) != 0? + jz DoneCopyRemainingCountNKSmall + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +DoneCopyRemainingCountNKSmall: + dec r10 + jz ProcessPaddedMatrixBData + add rsi,r8 ; advance next matrix B by 1 row + mov rbp,rdi + jmp CopyUnalignedRowLoop + +ReduceColumnSumVectorNUnaligned: + vpmaddwd ymm0,ymm0,ymm7 ; multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm7 ; multiply by offset and reduce + vmovdqu YMMWORD PTR [r11],ymm0 + vmovdqu YMMWORD PTR [r11+32],ymm1 + jmp ExitRoutine + + NESTED_END MlasGemmU8S8CopyPackBAvx2, _TEXT + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate a single row of the +; output block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; Vec1Reg - Supplies the high block accumulator register (when ColumnCount +; is 16). +; +; Vec2Reg - Supplies the low block accumulator register. +; +; Implicit Arguments: +; +; ymm0 - Supplies the first vector loaded from matrix B. +; +; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount +; is 16). +; +; ymm2 - Supplies the broadcast value loaded from matrix A. +; +; ymm12 - Supplies a 256-bit vector with the broadcast word value 0x0001.
+; + +MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg + + vpmaddubsw ymm3,ymm2,ymm0 + vpmaddwd ymm3,ymm3,ymm12 +IF ColumnCount EQ 16 + vpaddd Vec1Reg,Vec1Reg,ymm3 + vpmaddubsw ymm2,ymm2,ymm1 + vpmaddwd ymm2,ymm2,ymm12 + vpaddd Vec2Reg,Vec2Reg,ymm2 +ELSE + vpaddd Vec2Reg,Vec2Reg,ymm3 +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm11 - Supplies the block accumulators. +; +; ymm12 - Supplies a 256-bit vector with the broadcast word value 0x0001.
+; + +ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + +IF RowCount EQ 1 + vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset] + vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset] + vpmaddwd ymm3,ymm3,ymm12 +IF ColumnCount EQ 16 + vpaddd ymm4,ymm4,ymm3 + vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32] + vpmaddwd ymm2,ymm2,ymm12 + vpaddd ymm5,ymm5,ymm2 +ELSE + vpaddd ymm5,ymm5,ymm3 +ENDIF +ELSE + vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] + EmitIfCountGE ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to execute the block compute macro multiple +; times and advancing the matrix A and matrix B data pointers. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; ymm4-ymm11 - Supplies the block accumulators. +; + +ComputeBlockLoop MACRO ColumnCount, RowCount + + LOCAL ComputeBlockBy1Loop + + mov rsi,r9 ; reload row length remaining + +ComputeBlockBy1Loop: + ComputeBlock ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 quad +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 quad +ENDIF + add rdx,64 ; advance matrix B + sub rsi,4 + jnz ComputeBlockBy1Loop + + ENDM + +;++ +; +; Routine Description: +; +; This routine is an inner kernel to compute matrix multiplication for a +; set of rows. +; +; Arguments: +; +; A (rcx) - Supplies the address of matrix A. 
The matrix data has been packed +; using MlasGemmU8S8CopyPackAAvx2. +; +; B (rdx) - Supplies the address of matrix B. The matrix data has been packed +; using MlasGemmU8S8CopyPackBAvx2. +; +; C (r8) - Supplies the address of matrix C. +; +; QuadCountK (r9) - Supplies the number of quad columns from matrix A and the +; number of quad rows from matrix B to iterate over. +; +; CountM - Supplies the maximum number of rows that can be processed for +; matrix A and matrix C. The actual number of rows handled for this +; invocation depends on the kernel implementation. +; +; CountN - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; ldc - Supplies the first dimension of matrix C. +; +; RowSumVector - Supplies the sum of each row from matrix A multiplied by the +; zero point offset of matrix B. These values are accumulated into every +; row of matrix C. +; +; ColumnSumVector - Supplies the sum of each column from matrix B multiplied +; by the zero point offset of matrix A. These values are accumulated into +; every column of matrix C. +; +; DepthValue - Supplies the value CountK multiplied by the zero point offset +; of matrix A multiplied by the zero point offset of matrix B. This value is +; accumulated into every element of matrix C. +; +; ZeroMode - Supplies true if the output matrix must be zero initialized, +; else false if the output matrix is accumulated into. +; +; Return Value: +; +; Returns the number of rows handled.
+; +;-- + + NESTED_ENTRY MlasGemmU8S8KernelAvx2, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + alloc_stack (GemmU8X8KernelFrame.SavedR13) + save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 + save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7 + save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8 + save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9 + save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10 + save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11 + save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12 + + END_PROLOGUE + + mov rdi,rcx + mov rbp,GemmU8X8KernelFrame.CountN[rsp] + mov rax,GemmU8X8KernelFrame.ldc[rsp] + shl rax,2 ; convert ldc to bytes + shl r9,2 ; convert to row length + movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] + mov r11,GemmU8X8KernelFrame.CountM[rsp] + mov r12,GemmU8X8KernelFrame.RowSumVector[rsp] + mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp] + vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] + vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] + +; +; Process CountM rows of the matrices. +; + + cmp r11,3 + ja ProcessCountM4 + je ProcessCountM3 + cmp r11,1 + je ProcessCountM1 + +ProcessCountM2: + ProcessCountM 2 + +ProcessCountM4: + mov r11d,4 ; return 4 rows handled + ProcessCountM 4, Fallthrough + +; +; Restore non-volatile registers and return. 
+; + +ExitKernel: + mov eax,r11d + vzeroupper + movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp] + movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp] + movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp] + movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp] + movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp] + movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp] + movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp] + add rsp,(GemmU8X8KernelFrame.SavedR13) + + BEGIN_EPILOGUE + + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + +ProcessCountM1: + ProcessCountM 1 + +ProcessCountM3: + ProcessCountM 3 + + NESTED_END MlasGemmU8S8KernelAvx2, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm new file mode 100644 index 0000000000..cb3b819476 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm @@ -0,0 +1,130 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8S8KernelAvx512BW.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX512BW instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE QgemmU8S8KernelAvx512Common.inc + .list + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate a single cell of the +; output block. +; +; Arguments: +; +; AccumReg - Supplies the register to accumulate into. +; +; Mult1Reg - Supplies the first multiplication operand register. +; +; Mult2Reg - Supplies the second multiplication operand register. +; +; Implicit Arguments: +; +; zmm4 - Supplies a scratch register for intermediate results. +; +; zmm5 - Supplies a 512-bit vector with the broadcast word value 0x0001.
+; + +MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg + + vpmaddubsw zmm4,Mult1Reg,Mult2Reg + vpmaddwd zmm4,zmm4,zmm5 + vpaddd AccumReg,AccumReg,zmm4 + + ENDM + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r14 - Supplies the stride in bytes between packed blocks of matrix B. +; +; zmm14-zmm31 - Supplies the block accumulators. +; + +ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + +IF ColumnCount GE 48 + vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset] + vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset] + vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset] +ELSEIF ColumnCount GE 32 + vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset] + vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset] +ELSE + vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount,
4, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + + ENDM + +; +; Generate the GEMM kernel. +; + +GemmU8X8KernelAvx512Function U8S8, Avx512BW + + END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc new file mode 100644 index 0000000000..7d1d4fb10f --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Common.inc @@ -0,0 +1,91 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8S8KernelAvx512Common.inc +; +; Abstract: +; +; This module contains common kernel macros and structures for the quantized +; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and +; AVX512VNNI kernels. +; +;-- + +INCLUDE QgemmU8X8KernelAvx512Common.inc + +; +; Macro Description: +; +; This macro generates code to execute the block compute macro multiple +; times and advancing the matrix A and matrix B data pointers. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r14 - Supplies the stride in bytes between packed blocks of matrix B. +; +; zmm14-zmm31 - Supplies the block accumulators.
+; + +ComputeBlockLoop MACRO ColumnCount, RowCount + + LOCAL ComputeBlockBy4Loop + LOCAL ProcessRemainingBlocks + LOCAL ComputeBlockBy1Loop + LOCAL ComputeBlockLoopExit + + mov rsi,r9 ; reload row length remaining + +IF ((RowCount AND 1) EQ 0) + sub rsi,4*4 + jb ProcessRemainingBlocks + +ComputeBlockBy4Loop: + ComputeBlock ColumnCount, RowCount, 0*64, 0 + ComputeBlock ColumnCount, RowCount, 1*64, 4 + ComputeBlock ColumnCount, RowCount, 2*64, 8 + ComputeBlock ColumnCount, RowCount, 3*64, 12 + add rcx,4*4 ; advance matrix A by 4 quads +IF RowCount GT 3 + add rbx,4*4 ; advance matrix A plus 3 rows by 4 quads +ENDIF + add rdx,4*64 ; advance matrix B + sub rsi,4*4 ; decrement quads remaining + jae ComputeBlockBy4Loop + +ProcessRemainingBlocks: + add rsi,4*4 ; correct for over-subtract above + jz ComputeBlockLoopExit +ENDIF + +ComputeBlockBy1Loop: + ComputeBlock ColumnCount, RowCount, 0, 0 + add rcx,4 ; advance matrix A by 1 quad +IF RowCount GT 3 + add rbx,4 ; advance matrix A plus 3 rows by 1 quad +ENDIF + add rdx,64 ; advance matrix B + sub rsi,4 ; decrement quads remaining + jnz ComputeBlockBy1Loop + +ComputeBlockLoopExit: + + ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm new file mode 100644 index 0000000000..ddab81a6a4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm @@ -0,0 +1,102 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8S8KernelAvx512Vnni.asm +; +; Abstract: +; +; This module implements the kernels for the quantized integer matrix/matrix +; multiply operation (QGEMM). +; +; This implementation uses AVX512VNNI instructions.
+; +;-- + + .xlist +INCLUDE mlasi.inc +INCLUDE QgemmU8S8KernelAvx512Common.inc +INCLUDE AssembleAvx512Vnni.inc + .list + +; +; Macro Description: +; +; This macro generates code to multiply and accumulate each row of the output +; block. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; +; Implicit Arguments: +; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r14 - Supplies the stride in bytes between packed blocks of matrix B. +; +; zmm14-zmm31 - Supplies the block accumulators. +; + +ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset + +IF ColumnCount GE 48 + vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset] + vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset] + vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset] +ELSEIF ColumnCount GE 32 + vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset] + vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset] +ELSE + vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 4,
ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + + ENDM + +; +; Generate the GEMM kernel. +; + +GemmU8X8KernelAvx512Function U8S8, Avx512Vnni + + END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm index 60d58e6206..904ff1b60a 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm @@ -19,10 +19,9 @@ .xlist INCLUDE mlasi.inc +INCLUDE QgemmU8X8KernelAvx2Common.inc .list - EXTERN MlasMaskMoveAvx:NEAR - ; ; Stack frame layout for the U8U8 CopyPackA routine. ; @@ -73,44 +72,6 @@ GemmU8U8CopyPackBFrame STRUCT GemmU8U8CopyPackBFrame ENDS -; -; Stack frame layout for the U8U8 kernel. -; - -GemmU8U8KernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - SavedR14 QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - ldc QWORD ? - RowSumVector QWORD ? - ColumnSumVector QWORD ? - DepthValue QWORD ? - ZeroMode QWORD ? 
- -GemmU8U8KernelFrame ENDS - ;++ ; ; Routine Description: @@ -157,10 +118,10 @@ GemmU8U8KernelFrame ENDS push_reg r12 push_reg r13 alloc_stack (GemmU8U8CopyPackAFrame.SavedR13) - save_xmm128_avx xmm6,GemmU8U8CopyPackAFrame.SavedXmm6 - save_xmm128_avx xmm7,GemmU8U8CopyPackAFrame.SavedXmm7 - save_xmm128_avx xmm8,GemmU8U8CopyPackAFrame.SavedXmm8 - save_xmm128_avx xmm9,GemmU8U8CopyPackAFrame.SavedXmm9 + save_xmm128 xmm6,GemmU8U8CopyPackAFrame.SavedXmm6 + save_xmm128 xmm7,GemmU8U8CopyPackAFrame.SavedXmm7 + save_xmm128 xmm8,GemmU8U8CopyPackAFrame.SavedXmm8 + save_xmm128 xmm9,GemmU8U8CopyPackAFrame.SavedXmm9 END_PROLOGUE @@ -195,13 +156,15 @@ GemmU8U8KernelFrame ENDS ; ; Process 4 rows of matrix A in a loop. ; -; For each row, zero extend the source bytes to 16-bits and write to the packed -; buffer. The packed buffer has the same data ordering as the source bytes, but -; the stride is CountK aligned up to an even number of 16-bit values. +; Zero extend the source bytes to 16-bits and write to the packed buffer. +; +; The packed buffer has the same data ordering as the source bytes, but CountK +; is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding +; bytes are zero filled. ; ; These 16-bit values are also accumulated into an intermediate per-row -; accumulator. CountK cannot be greater than 256 to avoid overflowing these -; 16-bit accumulators. +; accumulator. CountK cannot be greater than 128 to avoid overflowing these +; signed 16-bit accumulators. 
; sub r9,4 @@ -221,12 +184,12 @@ ProcessNextRowM4: jb ProcessRemainingColumnsM4 ProcessNextColumnLoopM4: - lea rax,[rdx+r8*2] ; compute matrix A plus two rows + lea rax,[rdx+r8*2] ; compute matrix A plus 2 rows vpmovzxbw ymm4,XMMWORD PTR [rdx] vpmovzxbw ymm5,XMMWORD PTR [rdx+r8] vpmovzxbw ymm6,XMMWORD PTR [rax] vpmovzxbw ymm7,XMMWORD PTR [rax+r8] - lea rax,[rcx+r11*4] ; compute matrix D plus two rows + lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows vmovdqu YMMWORD PTR [rcx],ymm4 vmovdqu YMMWORD PTR [rcx+r11*2],ymm5 vmovdqu YMMWORD PTR [rax],ymm6 @@ -252,7 +215,7 @@ ProcessRemainingColumnsM4: mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData test bl,8 ; (CountK & 8) != 0? jz CopyRemainingCountKLessThan8M4 - lea r13,[rdx+r8*2] ; compute matrix A plus two rows + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows mov rax,QWORD PTR [rdx] mov QWORD PTR [rbp],rax mov rax,QWORD PTR [rdx+r8] @@ -267,7 +230,7 @@ ProcessRemainingColumnsM4: CopyRemainingCountKLessThan8M4: test bl,4 ; (CountK & 4) != 0? jz CopyRemainingCountKLessThan4M4 - lea r13,[rdx+r8*2] ; compute matrix A plus two rows + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows mov eax,DWORD PTR [rdx] mov DWORD PTR [rbp],eax mov eax,DWORD PTR [rdx+r8] @@ -282,7 +245,7 @@ CopyRemainingCountKLessThan8M4: CopyRemainingCountKLessThan4M4: test bl,2 ; (CountK & 2) != 0? jz CopyRemainingCountKLessThan2M4 - lea r13,[rdx+r8*2] ; compute matrix A plus two rows + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows movzx eax,WORD PTR [rdx] mov WORD PTR [rbp],ax movzx eax,WORD PTR [rdx+r8] @@ -297,7 +260,7 @@ CopyRemainingCountKLessThan4M4: CopyRemainingCountKLessThan2M4: test bl,1 ; (CountK & 1) != 0? 
jz ProcessPaddedMatrixADataM4 - lea r13,[rdx+r8*2] ; compute matrix A plus two rows + lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows movzx eax,BYTE PTR [rdx] mov BYTE PTR [rbp],al movzx eax,BYTE PTR [rdx+r8] @@ -316,7 +279,7 @@ ProcessPaddedMatrixADataM4: vpmovzxbw ymm5,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+16] vpmovzxbw ymm6,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32] vpmovzxbw ymm7,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+48] - lea rax,[rcx+r11*4] ; compute matrix D plus two rows + lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 vpmaskmovd YMMWORD PTR [rcx+r11*2],ymm9,ymm5 vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6 @@ -334,22 +297,14 @@ ProcessPaddedMatrixADataM4: ; ReduceRowSumVectorM4: - vpunpckldq ymm4,ymm0,ymm1 ; [A5 B5 A4 B4 A1 B1 A0 B0] - vpunpckhdq ymm5,ymm0,ymm1 ; [A7 B7 A6 B6 A3 B3 A2 B2] - vpunpckldq ymm6,ymm2,ymm3 ; [C5 D5 C4 D4 C1 D1 C0 D0] - vpunpckhdq ymm7,ymm2,ymm3 ; [C7 D7 C6 D6 C3 D3 C2 D2] - vpunpcklqdq ymm0,ymm4,ymm6 ; [A4 B4 C4 D4 A0 B0 C0 D0] - vpunpckhqdq ymm1,ymm4,ymm6 ; [A5 B5 C5 D5 A1 B1 C1 D1] - vpunpcklqdq ymm2,ymm5,ymm7 ; [A6 B6 C6 D6 A2 B2 C2 D2] - vpunpckhqdq ymm3,ymm5,ymm7 ; [A7 B7 C7 D7 A3 B3 C3 D3] - vpaddw ymm0,ymm0,ymm1 ; reduction - vpaddw ymm0,ymm0,ymm2 - vpaddw ymm0,ymm0,ymm3 + vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0 + vphaddw ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2 + vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0 vextracti128 xmm1,ymm0,1 ; extract high pairs - vpaddw xmm0,xmm0,xmm1 ; reduction - vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce + vpaddw xmm0,xmm0,xmm1 ; reduce low/high pairs + vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce 32-bit sum vmovdqu XMMWORD PTR [r12],xmm0 - add r12,4*4 ; advance row sum vector by 4 dwords + add r12,4*4 ; advance row sum vector by 4 DWORDs sub r9,4 ; subtract rows remaining jae ProcessNextRowM4 @@ -449,10 +404,10 @@ 
ReduceRowSumVectorM1: ExitRoutine: vzeroupper - vmovaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp] - vmovaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp] - vmovaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp] - vmovaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp] + movaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp] + movaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp] + movaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp] + movaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp] add rsp,(GemmU8U8CopyPackAFrame.SavedR13) BEGIN_EPILOGUE @@ -537,9 +492,9 @@ ProcessNextColumnN16: jb ProcessRemainingRowsN16 ProcessNextRowLoopN16: - vmovdqu xmm2,XMMWORD PTR [rdx] ; load two rows + vmovdqu xmm2,XMMWORD PTR [rdx] ; load 2 rows vmovdqu xmm3,XMMWORD PTR [rdx+r8] - lea rdx,[rdx+r8*2] ; advance matrix B by two rows + lea rdx,[rdx+r8*2] ; advance matrix B by 2 rows vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data vpunpckhbw xmm3,xmm2,xmm3 vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows @@ -549,7 +504,7 @@ ProcessNextRowLoopN16: add rcx,32 ; advance matrix D by 32 bytes vpaddw ymm0,ymm0,ymm4 ; accumulate per column vpaddw ymm1,ymm1,ymm3 - sub rbx,2 ; subtract columns remaining + sub rbx,2 ; subtract rows remaining jae ProcessNextRowLoopN16 ProcessRemainingRowsN16: @@ -569,7 +524,7 @@ ReduceColumnSumVectorN16: vpmaddwd ymm1,ymm1,ymm5 ; multiply by offset and reduce vmovdqu YMMWORD PTR [r11],ymm0 vmovdqu YMMWORD PTR [r11+32],ymm1 - add r11,64 ; advance column sum vector by 16 dwords + add r11,16*4 ; advance column sum vector by 16 DWORDs sub r9,16 ; subtract columns remaining jae ProcessNextColumnN16 @@ -654,7 +609,7 @@ ProcessPaddedMatrixBDataK2: vpmovzxbw ymm3,xmm3 vpaddw ymm0,ymm0,ymm4 ; accumulate per column vpaddw ymm1,ymm1,ymm3 - lea rsi,[rsi+r8*2] ; advance next matrix B by two rows + lea rsi,[rsi+r8*2] ; advance next matrix B by 2 rows add rcx,32 ; advance matrix D by 32 bytes sub r10,2 ; subtract columns remaining jae ProcessNextRowLoopNUnaligned @@ -739,13 +694,12 @@ 
ReduceColumnSumVectorNUnaligned: MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg -IF ColumnCount EQ 16 vpmaddwd ymm3,ymm2,ymm0 +IF ColumnCount EQ 16 vpaddd Vec1Reg,Vec1Reg,ymm3 vpmaddwd ymm2,ymm2,ymm1 vpaddd Vec2Reg,Vec2Reg,ymm2 ELSE - vpmaddwd ymm3,ymm2,ymm0 vpaddd Vec2Reg,Vec2Reg,ymm3 ENDIF @@ -775,7 +729,7 @@ ENDIF ; ; rdx - Supplies the address into the matrix B data. ; -; r10 - Supplies the length in bytes of a row from matrix A. +; r9 - Supplies the length in bytes of a row from matrix A. ; ; ymm4-ymm15 - Supplies the block accumulators. ; @@ -786,15 +740,15 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset EmitIfCountGE ColumnCount, 16, EmitIfCountGE RowCount, 1, EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, EmitIfCountGE RowCount, 3, EmitIfCountGE RowCount, 4, EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, EmitIfCountGE RowCount, 6, ENDM @@ -802,8 +756,8 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset ; ; Macro Description: ; -; This macro generates code to produce an output block for a set of columns -; and rows. +; This macro generates code to execute the block compute macro multiple +; times and advancing the matrix A and matrix B data pointers. ; ; Arguments: ; @@ -813,281 +767,59 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset ; ; Implicit Arguments: ; -; rax - Supplies the length in bytes of a row from matrix C. +; rbx - Supplies the address into the matrix A data plus 3 rows. ; ; rcx - Supplies the address into the matrix A data. ; ; rdx - Supplies the address into the matrix B data. ; -; r9 - Supplies the number of paired columns from matrix A and the number of -; paired rows from matrix B to iterate over. 
+; r9 - Supplies the length in bytes of a row from matrix A. ; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r12 - Supplies the address of the row sum vector. -; -; r13 - Supplies the address of the column sum vector. +; ymm4-ymm15 - Supplies the block accumulators. ; -ProduceOutputBlock MACRO ColumnCount, RowCount +ComputeBlockLoop MACRO ColumnCount, RowCount - LOCAL ComputeBlockLoop + LOCAL ComputeBlockBy2Loop LOCAL ProcessRemainingBlocks + LOCAL ComputeBlockBy1Loop LOCAL ComputeBlockLoopExit -; -; Initialize the accumulators with the sum of the global depth value constant, -; the column sums, and the row sums. -; - - vpbroadcastd ymm1,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp] -IF ColumnCount EQ 16 - vpaddd ymm0,ymm1,YMMWORD PTR [r13] - vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] - add r13,16*4 ; advance ColumnSumVector by 16 columns -ELSE - vpaddd ymm1,ymm1,YMMWORD PTR [r13] -ENDIF - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - -; -; Iterate over PairedCountK elements from matrix A and matrix B. -; -; Unrolling the loop to do two iterations improves performance slightly at the -; cost of larger code size. Balance this by only unrolling for the common case -; of computing 16 columns for an even number of rows. 
-; - - mov rsi,r9 ; reload PairedCountK -IF RowCount GT 3 - lea rbx,[r10*2+r10] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF + mov rsi,r9 ; reload row length remaining IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) - sub rsi,2 + sub rsi,2*4 jb ProcessRemainingBlocks -ComputeBlockLoop: +ComputeBlockBy2Loop: ComputeBlock ColumnCount, RowCount, 0, 0 ComputeBlock ColumnCount, RowCount, 32, 4 add rcx,2*4 ; advance matrix A by 2 pairs IF RowCount GT 3 add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs ENDIF - add rdx,2*32 ; advance matrix B by 64 columns - sub rsi,2 ; subtract pairs remaining - jae ComputeBlockLoop + add rdx,2*32 ; advance matrix B + sub rsi,2*4 + jae ComputeBlockBy2Loop ProcessRemainingBlocks: - add rsi,2 ; correct for over-subtract above + add rsi,2*4 ; correct for over-subtract above jz ComputeBlockLoopExit ComputeBlock ColumnCount, RowCount, 0, 0 - add rdx,32 ; advance matrix B by 32 columns + add rdx,32 ; advance matrix B ELSE -ComputeBlockLoop: +ComputeBlockBy1Loop: ComputeBlock ColumnCount, RowCount, 0, 0 add rcx,4 ; advance matrix A by 1 pair IF RowCount GT 3 add rbx,4 ; advance matrix A plus 3 rows by 1 pair ENDIF - add rdx,32 - dec rsi ; decrement pairs remaining - jnz ComputeBlockLoop + add rdx,32 ; advance matrix B + sub rsi,4 + jnz ComputeBlockBy1Loop ENDIF ComputeBlockLoopExit: -IF RowCount GT 3 - lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows - add rbx,rax -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Fallthrough - Supplies a non-blank value if the macro may fall through to -; the ExitKernel label. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; r8 - Supplies the address of matrix C. 
-; -; rdi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r9 - Supplies the number of paired columns from matrix A and the number of -; paired rows from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r12 - Supplies the address of the row sum vector. -; -; r13 - Supplies the address of the column sum vector. -; -; r14b - Supplies the zero mode flag. -; - -ProcessCountM MACRO RowCount, Fallthrough - - LOCAL ProcessNextColumnLoop16xN - LOCAL SkipAccumulateOutput16xNBlock - LOCAL OutputMasked16xNBlock - LOCAL ProcessRemainingCountN - LOCAL SkipAccumulateOutput8xNBlock - LOCAL SkipAccumulateOutputMasked16xNBlock - LOCAL OutputMasked8xNBlock - LOCAL SkipAccumulateOutputMasked8xNBlock - - cmp rbp,8 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop16xN: - ProduceOutputBlock 16, RowCount - sub rbp,16 - jb OutputMasked16xNBlock - test r14b,r14b ; ZeroMode? - jnz SkipAccumulateOutput16xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput16xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rdi ; reload matrix A - cmp rbp,8 - ja ProcessNextColumnLoop16xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - ProduceOutputBlock 8, RowCount - cmp rbp,8 - jb 
OutputMasked8xNBlock - test r14b,r14b ; ZeroMode? - jnz SkipAccumulateOutput8xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput8xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp ExitKernel - -OutputMasked16xNBlock: - test r14b,r14b ; ZeroMode? - jnz SkipAccumulateOutputMasked16xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutputMasked16xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,8*4 ; advance matrix C by 8 columns -IF RowCount GT 3 - add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns -ENDIF - add rbp,8 ; correct for over-subtract above - -OutputMasked8xNBlock: - mov DWORD PTR GemmU8U8KernelFrame.CountN[rsp],ebp - vpbroadcastd ymm0,DWORD PTR GemmU8U8KernelFrame.CountN[rsp] - vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx] - test r14b,r14b ; ZeroMode? 
- jnz SkipAccumulateOutputMasked8xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutputMasked8xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, -IFB - jmp ExitKernel -ENDIF ENDM @@ -1108,8 +840,8 @@ ENDIF ; ; C (r8) - Supplies the address of matrix C. ; -; PairedCountK (r9) - Supplies the number of paired columns from matrix A and -; the number of paired rows from matrix B to iterate over. +; PairCountK (r9) - Supplies the number of pair columns from matrix A and the +; number of pair rows from matrix B to iterate over. ; ; CountM - Supplies the maximum number of rows that can be processed for ; matrix A and matrix C. The actual number of rows handled for this @@ -1129,7 +861,7 @@ ENDIF ; every column of matrix C. ; ; DepthValue - Supplies the value CountK multiplied by the zero point offset -; of matrixA multplied by the zero point offset of matrix B. This value is +; of matrix A multplied by the zero point offset of matrix B. This value is ; accumulated into every element of matrix C. 
; ; ZeroMode - Supplies true if the output matrix must be zero initialized, @@ -1149,30 +881,29 @@ ENDIF push_reg rdi push_reg r12 push_reg r13 - push_reg r14 - alloc_stack (GemmU8U8KernelFrame.SavedR14) - save_xmm128_avx xmm6,GemmU8U8KernelFrame.SavedXmm6 - save_xmm128_avx xmm7,GemmU8U8KernelFrame.SavedXmm7 - save_xmm128_avx xmm8,GemmU8U8KernelFrame.SavedXmm8 - save_xmm128_avx xmm9,GemmU8U8KernelFrame.SavedXmm9 - save_xmm128_avx xmm10,GemmU8U8KernelFrame.SavedXmm10 - save_xmm128_avx xmm11,GemmU8U8KernelFrame.SavedXmm11 - save_xmm128_avx xmm12,GemmU8U8KernelFrame.SavedXmm12 - save_xmm128_avx xmm13,GemmU8U8KernelFrame.SavedXmm13 - save_xmm128_avx xmm14,GemmU8U8KernelFrame.SavedXmm14 - save_xmm128_avx xmm15,GemmU8U8KernelFrame.SavedXmm15 + alloc_stack (GemmU8X8KernelFrame.SavedR13) + save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6 + save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7 + save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8 + save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9 + save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10 + save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11 + save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12 + save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13 + save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 + save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 END_PROLOGUE mov rdi,rcx - mov rbp,GemmU8U8KernelFrame.CountN[rsp] - mov rax,GemmU8U8KernelFrame.ldc[rsp] + mov rbp,GemmU8X8KernelFrame.CountN[rsp] + mov rax,GemmU8X8KernelFrame.ldc[rsp] shl rax,2 ; convert ldc to bytes - lea r10,[r9*4] - mov r11,GemmU8U8KernelFrame.CountM[rsp] - mov r12,GemmU8U8KernelFrame.RowSumVector[rsp] - mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp] - movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp] + shl r9,2 ; convert to row length + movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] + mov r11,GemmU8X8KernelFrame.CountM[rsp] + mov r12,GemmU8X8KernelFrame.RowSumVector[rsp] + mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp] ; ; Process CountM rows of the matrices. 
@@ -1204,21 +935,20 @@ ProcessCountM6: ExitKernel: mov eax,r11d vzeroupper - vmovaps xmm6,GemmU8U8KernelFrame.SavedXmm6[rsp] - vmovaps xmm7,GemmU8U8KernelFrame.SavedXmm7[rsp] - vmovaps xmm8,GemmU8U8KernelFrame.SavedXmm8[rsp] - vmovaps xmm9,GemmU8U8KernelFrame.SavedXmm9[rsp] - vmovaps xmm10,GemmU8U8KernelFrame.SavedXmm10[rsp] - vmovaps xmm11,GemmU8U8KernelFrame.SavedXmm11[rsp] - vmovaps xmm12,GemmU8U8KernelFrame.SavedXmm12[rsp] - vmovaps xmm13,GemmU8U8KernelFrame.SavedXmm13[rsp] - vmovaps xmm14,GemmU8U8KernelFrame.SavedXmm14[rsp] - vmovaps xmm15,GemmU8U8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmU8U8KernelFrame.SavedR14) + movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp] + movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp] + movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp] + movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp] + movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp] + movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp] + movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp] + movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp] + movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] + movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] + add rsp,(GemmU8X8KernelFrame.SavedR13) BEGIN_EPILOGUE - pop r14 pop r13 pop r12 pop rdi diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm index 8f4d0fa47f..25087fbafa 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm @@ -25,39 +25,26 @@ INCLUDE QgemmU8U8KernelAvx512Common.inc ; ; Macro Description: ; -; This macro generates code to multiply and accumulator a single row of the +; This macro generates code to multiply and accumulate a single cell of the ; output block. ; ; Arguments: ; -; ColumnCount - Supplies the number of columns to produce. +; AccumReg - Supplies the register to accumulate into.
; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 32). +; Mult1Reg - Supplies the first multiplication operand register. ; -; Vec2Reg - Supplies the low block accumulator register. +; Mult2Reg - Supplies the second multiplication operand register. ; ; Implicit Arguments: ; -; zmm28 - Supplies the first vector loaded from matrix B. -; -; zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 32). -; -; zmm30 - Supplies the broadcast value loaded from matrix A. +; zmm4 - Supplies a scratch register for intermediate results. ; -MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg +MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg -IF ColumnCount EQ 32 - vpmaddwd zmm31,zmm30,zmm28 - vpaddd Vec1Reg,Vec1Reg,zmm31 - vpmaddwd zmm30,zmm30,zmm29 - vpaddd Vec2Reg,Vec2Reg,zmm30 -ELSE - vpmaddwd zmm31,zmm30,zmm28 - vpaddd Vec2Reg,Vec2Reg,zmm31 -ENDIF + vpmaddwd zmm4,Mult1Reg,Mult2Reg + vpaddd AccumReg,AccumReg,zmm4 ENDM @@ -73,6 +60,10 @@ ENDIF ; ; RowCount - Supplies the number of rows to produce. ; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; ; Implicit Arguments: ; ; rbx - Supplies the address into the matrix A data plus 3 rows. @@ -81,27 +72,49 @@ ENDIF ; ; rdx - Supplies the address into the matrix B data. ; -; r10 - Supplies the length in bytes of a row from matrix A. +; r9 - Supplies the length in bytes of a row from matrix A. ; -; zmm16-zmm27 - Supplies the block accumulators. +; r14 - Supplies the stride in bytes between packed blocks of matrix B. +; +; zmm14-zmm31 - Supplies the block accumulators.
; -ComputeBlock MACRO ColumnCount, RowCount +ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - vpmovzxbw zmm28,YMMWORD PTR [rdx] - EmitIfCountGE ColumnCount, 32, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, +IF ColumnCount GE 48 + vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset] + vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset] + vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset] +ELSEIF ColumnCount GE 32 + vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset] + vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset] +ELSE + vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, ENDM @@ -109,6 +122,6 @@ ComputeBlock MACRO ColumnCount, RowCount ; Generate the 
GEMM kernel. ; -GemmU8U8KernelAvx512Function Avx512BW +GemmU8X8KernelAvx512Function U8U8, Avx512BW END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc index 1cd5cdc732..b46cf32b9b 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Common.inc @@ -16,39 +16,13 @@ ; ;-- -; -; Stack frame layout for the U8U8 kernel. -; - -GemmU8U8KernelFrame STRUCT - - SavedR14 QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - ldc QWORD ? - RowSumVector QWORD ? - ColumnSumVector QWORD ? - DepthValue QWORD ? - ZeroMode QWORD ? - -GemmU8U8KernelFrame ENDS +INCLUDE QgemmU8X8KernelAvx512Common.inc ; ; Macro Description: ; -; This macro generates code to produce an output block for a set of columns -; and rows. +; This macro generates code to execute the block compute macro multiple +; times and advance the matrix A and matrix B data pointers. ; ; Arguments: ; @@ -58,328 +32,33 @@ GemmU8U8KernelFrame ENDS ; ; Implicit Arguments: ; +; rbx - Supplies the address into the matrix A data plus 3 rows. +; ; rcx - Supplies the address into the matrix A data. ; ; rdx - Supplies the address into the matrix B data. ; -; r9 - Supplies the number of paired columns from matrix A and the number of -; paired rows from matrix B to iterate over. +; r9 - Supplies the length in bytes of a row from matrix A. ; -; r10 - Supplies the length in bytes of a row from matrix A. +; r14 - Supplies the stride in bytes between packed blocks of matrix B. ; -; r12 - Supplies the address of the row sum vector. -; -; r13 - Supplies the address of the column sum vector.
+; zmm14-zmm31 - Supplies the block accumulators. ; -ProduceOutputBlock MACRO ColumnCount, RowCount +ComputeBlockLoop MACRO ColumnCount, RowCount - LOCAL ComputeBlockLoop + LOCAL ComputeBlockBy1Loop -; -; Initialize the accumulators with the sum of the global depth value constant, -; the column sums, and the row sums. -; + mov rsi,r9 ; reload row length remaining - vpbroadcastd zmm31,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp] -IF ColumnCount EQ 32 - vpaddd zmm30,zmm31,ZMMWORD PTR [r13] - vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64] - add r13,32*4 ; advance ColumnSumVector by 32 columns -ELSE - vpaddd zmm31,zmm31,ZMMWORD PTR [r13] -ENDIF - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCountGE RowCount, 6, - -; -; Iterate over PairedCountK elements from matrix A and matrix B. -; - - mov rsi,r9 ; reload PairedCountK -IF RowCount GT 3 - lea rbx,[r10*2+r10] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - -ComputeBlockLoop: - ComputeBlock ColumnCount, RowCount +ComputeBlockBy1Loop: + ComputeBlock ColumnCount, RowCount, 0, 0 add rcx,4 ; advance matrix A by 1 pair IF RowCount GT 3 add rbx,4 ; advance matrix A plus 3 rows by 1 pair ENDIF - add rdx,32 - dec rsi ; decrement pairs remaining - jnz ComputeBlockLoop - -IF RowCount GT 3 - lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows - add rbx,rax -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. 
-; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; r8 - Supplies the address of matrix C. -; -; rdi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r9 - Supplies the number of paired columns from matrix A and the number of -; paired rows from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r12 - Supplies the address of the row sum vector. -; -; r13 - Supplies the address of the column sum vector. -; -; r14b - Supplies the zero mode flag. -; - -ProcessCountM MACRO RowCount - - LOCAL ProcessNextColumnLoop32xN - LOCAL SkipAccumulateOutput32xNBlock - LOCAL Output16xNBlock - LOCAL Output16xNBlockWithMask - LOCAL SkipAccumulateOutput16xNBlockWithMask - LOCAL ProcessRemainingCountN - - cmp rbp,16 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop32xN: - ProduceOutputBlock 32, RowCount - lea rdx,[rdx+r10*8] ; advance matrix B by 8*PairedCountK - test r14b,r14b ; ZeroMode? - jnz SkipAccumulateOutput32xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput32xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns -IF RowCount GT 3 - add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns -ENDIF - sub rbp,16 - -Output16xNBlock: - sub rbp,16 - jae Output16xNBlockWithMask - lea ecx,[ebp+16] ; correct for over-subtract above - mov esi,1 - shl esi,cl - dec esi - kmovw k1,esi ; update mask for remaining columns - xor ebp,ebp ; no more columns remaining - -Output16xNBlockWithMask: - test r14b,r14b ; ZeroMode? 
- jnz SkipAccumulateOutput16xNBlockWithMask - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput16xNBlockWithMask: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rdi ; reload matrix A - cmp rbp,16 - ja ProcessNextColumnLoop32xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - ProduceOutputBlock 16, RowCount - jmp Output16xNBlock - - ENDM - -; -; Macro Description: -; -; This macro generates the common AVX512 code for the inner kernel to compute -; matrix multiplication. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string for function tags. -; - -GemmU8U8KernelAvx512Function MACRO Isa - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmU8U8CopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmU8U8CopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PairedCountK (r9) - Supplies the number of paired columns from matrix A and -; the number of paired rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumVector - Supplies the sum of each row from matrix A multiplied by the -; zero point offset of matrix B. 
These values are accumulated into every -; row of matrix C. -; -; ColumnSumVector - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. These values are accumulated into -; every column of matrix C. -; -; DepthValue - Supplies the value CountK multiplied by the zero point offset -; of matrixA multplied by the zero point offset of matrix B. This value is -; accumulated into every element of matrix C. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemmU8U8Kernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - push_reg r14 - - END_PROLOGUE - - mov rdi,rcx - mov rbp,GemmU8U8KernelFrame.CountN[rsp] - mov rax,GemmU8U8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - lea r10,[r9*4] - mov r11,GemmU8U8KernelFrame.CountM[rsp] - mov r12,GemmU8U8KernelFrame.RowSumVector[rsp] - mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp] - movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp] - mov esi,-1 - kmovw k1,esi ; update mask to write all columns - -; -; Process CountM rows of the matrices. -; - - cmp r11,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM6: - mov r11d,6 ; return 6 rows handled - ProcessCountM 6 - -; -; Restore non-volatile registers and return. 
-; -ExitKernel: - mov eax,r11d - - BEGIN_EPILOGUE - - pop r14 - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM5: - ProcessCountM 5 - - NESTED_END MlasGemmU8U8Kernel&Isa&, _TEXT + add rdx,32 ; advance matrix B + sub rsi,4 + jnz ComputeBlockBy1Loop ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm index d2b6b69632..99c72a4f2c 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm @@ -35,6 +35,10 @@ INCLUDE AssembleAvx512Vnni.inc ; ; RowCount - Supplies the number of rows to produce. ; +; VectorOffset - Supplies the byte offset from matrix B to fetch elements. +; +; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. +; ; Implicit Arguments: ; ; rbx - Supplies the address into the matrix A data plus 3 rows. @@ -43,41 +47,56 @@ INCLUDE AssembleAvx512Vnni.inc ; ; rdx - Supplies the address into the matrix B data. ; -; r10 - Supplies the length in bytes of a row from matrix A. +; r9 - Supplies the length in bytes of a row from matrix A. ; -; zmm16-zmm27 - Supplies the block accumulators. +; r14 - Supplies the stride in bytes between packed blocks of matrix B. +; +; zmm14-zmm31 - Supplies the block accumulators.
; -ComputeBlock MACRO ColumnCount, RowCount +ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - vpmovzxbw zmm28,YMMWORD PTR [rdx] -IF ColumnCount EQ 32 - vpmovzxbw zmm29,YMMWORD PTR [rdx+r10*8] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, +IF ColumnCount GE 32 +IF ColumnCount GE 48 + vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset] + vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset] + vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset] ELSE - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, + vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset] + vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE 
RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, +ELSE + vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset] + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, ENDIF ENDM @@ -86,6 +105,6 @@ ENDIF ; Generate the GEMM kernel. ; -GemmU8U8KernelAvx512Function Avx512Vnni +GemmU8X8KernelAvx512Function U8U8, Avx512Vnni END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc new file mode 100644 index 0000000000..86142a75f7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc @@ -0,0 +1,302 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; QgemmU8X8KernelAvx2Common.inc +; +; Abstract: +; +; This module contains common kernel macros and structures for the quantized +; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels. +; +;-- + + EXTERN MlasMaskMoveAvx:NEAR + +; +; Stack frame layout for the U8S8 and U8U8 kernels. +; + +GemmU8X8KernelFrame STRUCT + + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + SavedXmm9 OWORD ? + SavedXmm10 OWORD ? + SavedXmm11 OWORD ? + SavedXmm12 OWORD ? + SavedXmm13 OWORD ? + SavedXmm14 OWORD ? + SavedXmm15 OWORD ? + Padding QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountM QWORD ? + CountN QWORD ? + ldc QWORD ? + RowSumVector QWORD ? + ColumnSumVector QWORD ? + DepthValue QWORD ? + ZeroMode QWORD ? 
+ +GemmU8X8KernelFrame ENDS + +; +; Macro Description: +; +; This macro generates code to produce an output block for a set of columns +; and rows. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; +; ymm4-ymm15 - Supplies the block accumulators. +; + +ProduceOutputBlock MACRO ColumnCount, RowCount + +; +; Initialize the accumulators with the sum of the global depth value constant, +; the column sums, and the row sums. +; + + vpbroadcastd ymm1,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp] +IF ColumnCount EQ 16 + vpaddd ymm0,ymm1,YMMWORD PTR [r13] + vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] + add r13,16*4 ; advance ColumnSumVector by 16 columns +ELSE + vpaddd ymm1,ymm1,YMMWORD PTR [r13] +ENDIF + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCountGE RowCount, 1, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCountGE RowCount, 2, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCountGE RowCount, 3, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCountGE RowCount, 4, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCountGE RowCount, 5, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCountGE RowCount, 6, + +; +; Iterate over the length of a matrix A row to produce the output accumulators. 
+; + +IF RowCount GT 3 + lea rbx,[r9*2+r9] + add rbx,rcx ; compute matrix A plus 3 rows +ENDIF + ComputeBlockLoop ColumnCount, RowCount +IF RowCount GT 3 + lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows + add rbx,rax +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to compute matrix multiplication for a fixed set +; of rows. +; +; Arguments: +; +; RowCount - Supplies the number of rows to process. +; +; Fallthrough - Supplies a non-blank value if the macro may fall through to +; the ExitKernel label. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address of matrix A. +; +; rdx - Supplies the address of matrix B. +; +; r8 - Supplies the address of matrix C. +; +; rdi - Supplies the address of matrix A. +; +; rbp - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r10b - Supplies the zero mode flag. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; + +ProcessCountM MACRO RowCount, Fallthrough + + LOCAL ProcessNextColumnLoop16xN + LOCAL SkipAccumulateOutput16xNBlock + LOCAL OutputMasked16xNBlock + LOCAL ProcessRemainingCountN + LOCAL SkipAccumulateOutput8xNBlock + LOCAL SkipAccumulateOutputMasked16xNBlock + LOCAL OutputMasked8xNBlock + LOCAL SkipAccumulateOutputMasked8xNBlock + + cmp rbp,8 + jbe ProcessRemainingCountN + +ProcessNextColumnLoop16xN: + ProduceOutputBlock 16, RowCount + sub rbp,16 + jb OutputMasked16xNBlock + test r10b,r10b ; ZeroMode? 
+ jnz SkipAccumulateOutput16xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput16xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,8 + ja ProcessNextColumnLoop16xN + test rbp,rbp + jz ExitKernel + +ProcessRemainingCountN: + ProduceOutputBlock 8, RowCount + cmp rbp,8 + jb OutputMasked8xNBlock + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput8xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput8xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + jmp ExitKernel + +OutputMasked16xNBlock: + test r10b,r10b ; ZeroMode? 
+ jnz SkipAccumulateOutputMasked16xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutputMasked16xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,8*4 ; advance matrix C by 8 columns +IF RowCount GT 3 + add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns +ENDIF + add rbp,8 ; correct for over-subtract above + +OutputMasked8xNBlock: + mov DWORD PTR GemmU8X8KernelFrame.CountN[rsp],ebp + vpbroadcastd ymm0,DWORD PTR GemmU8X8KernelFrame.CountN[rsp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx] + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutputMasked8xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutputMasked8xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, +IFB + jmp ExitKernel +ENDIF + + ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc new file mode 100644 index 0000000000..5b364facb8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc @@ -0,0 +1,438 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. 
+; +; Module Name: +; +; QgemmU8X8KernelAvx512Common.inc +; +; Abstract: +; +; This module contains common kernel macros and structures for the quantized +; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and +; AVX512VNNI kernels. +; +;-- + +; +; Stack frame layout for the U8S8 and U8U8 kernels. +; + +GemmU8X8KernelFrame STRUCT + + SavedXmm14 OWORD ? + SavedXmm15 OWORD ? + SavedR14 QWORD ? + SavedR13 QWORD ? + SavedR12 QWORD ? + SavedRdi QWORD ? + SavedRsi QWORD ? + SavedRbx QWORD ? + SavedRbp QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + CountM QWORD ? + CountN QWORD ? + ldc QWORD ? + RowSumVector QWORD ? + ColumnSumVector QWORD ? + DepthValue QWORD ? + ZeroMode QWORD ? + +GemmU8X8KernelFrame ENDS + +; +; Macro Description: +; +; This macro generates code to produce an output block for a set of columns +; and rows. +; +; Arguments: +; +; ColumnCount - Supplies the number of columns to produce. +; +; RowCount - Supplies the number of rows to produce. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address into the matrix A data. +; +; rdx - Supplies the address into the matrix B data. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; + +ProduceOutputBlock MACRO ColumnCount, RowCount + +; +; Initialize the accumulators with the sum of the global depth value constant, +; the column sums, and the row sums. 
+; + + vpbroadcastd zmm3,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp] +IF ColumnCount GE 32 +IF ColumnCount GE 48 + vpaddd zmm2,zmm3,ZMMWORD PTR [r13] + vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64] + vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128] +ELSE + vpaddd zmm1,zmm3,ZMMWORD PTR [r13] + vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64] +ENDIF + add_immed r13,ColumnCount*4 ; advance ColumnSumVector by N columns +ELSE + vpaddd zmm0,zmm3,ZMMWORD PTR [r13] +ENDIF + EmitIfCount2GE RowCount, 1, ColumnCount, 16, + EmitIfCount2GE RowCount, 1, ColumnCount, 32, + EmitIfCount2GE RowCount, 1, ColumnCount, 48, + EmitIfCount2GE RowCount, 2, ColumnCount, 16, + EmitIfCount2GE RowCount, 2, ColumnCount, 32, + EmitIfCount2GE RowCount, 2, ColumnCount, 48, + EmitIfCount2GE RowCount, 3, ColumnCount, 16, + EmitIfCount2GE RowCount, 3, ColumnCount, 32, + EmitIfCount2GE RowCount, 3, ColumnCount, 48, + EmitIfCount2GE RowCount, 4, ColumnCount, 16, + EmitIfCount2GE RowCount, 4, ColumnCount, 32, + EmitIfCount2GE RowCount, 4, ColumnCount, 48, + EmitIfCount2GE RowCount, 5, ColumnCount, 16, + EmitIfCount2GE RowCount, 5, ColumnCount, 32, + EmitIfCount2GE RowCount, 5, ColumnCount, 48, + EmitIfCount2GE RowCount, 6, ColumnCount, 16, + EmitIfCount2GE RowCount, 6, ColumnCount, 32, + EmitIfCount2GE RowCount, 6, ColumnCount, 48, + +; +; Iterate over the length of a matrix A row to produce the output accumulators. +; + +IF RowCount GT 3 + lea rbx,[r9*2+r9] + add rbx,rcx ; compute matrix A plus 3 rows +ENDIF + ComputeBlockLoop ColumnCount, RowCount +IF RowCount GT 3 + lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows + add rbx,rax +ENDIF + + ENDM + +; +; Macro Description: +; +; This macro generates code to compute matrix multiplication for a fixed set +; of rows. +; +; Arguments: +; +; RowCount - Supplies the number of rows to process. +; +; Implicit Arguments: +; +; rax - Supplies the length in bytes of a row from matrix C. +; +; rcx - Supplies the address of matrix A. +; +; rdx - Supplies the address of matrix B. 
+; +; r8 - Supplies the address of matrix C. +; +; rdi - Supplies the address of matrix A. +; +; rbp - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; r9 - Supplies the length in bytes of a row from matrix A. +; +; r10b - Supplies the zero mode flag. +; +; r12 - Supplies the address of the row sum vector. +; +; r13 - Supplies the address of the column sum vector. +; +; r14 - Supplies the stride in bytes between packed blocks of matrix B. +; + +ProcessCountM MACRO RowCount + + LOCAL ProcessNextColumnLoop32xN + LOCAL Output32xNBlock + LOCAL SkipAccumulateOutput32xNBlock + LOCAL Output16xNBlock + LOCAL Output16xNBlockWithMask + LOCAL SkipAccumulateOutput16xNBlockWithMask + LOCAL ProcessRemainingCountN + LOCAL ProcessNextColumnLoop48xN + LOCAL SkipAccumulateOutput48xNBlock + + cmp rbp,32 + ja ProcessNextColumnLoop48xN + cmp rbp,16 + jbe ProcessRemainingCountN + +ProcessNextColumnLoop32xN: + ProduceOutputBlock 32, RowCount + add rdx,r14 ; advance matrix B by packed block stride + +Output32xNBlock: + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput32xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput32xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns +IF RowCount GT 3 + add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns +ENDIF + sub rbp,16 + +Output16xNBlock: + sub rbp,16 + jae Output16xNBlockWithMask + lea ecx,[ebp+16] ; correct for over-subtract above + mov esi,1 + shl esi,cl + dec esi + kmovw k1,esi ; update mask for remaining columns + xor ebp,ebp ; no more columns remaining + +Output16xNBlockWithMask: + test r10b,r10b ; ZeroMode?
+ jnz SkipAccumulateOutput16xNBlockWithMask + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput16xNBlockWithMask: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,32 + ja ProcessNextColumnLoop48xN + cmp rbp,16 + ja ProcessNextColumnLoop32xN + test rbp,rbp + jz ExitKernel + +ProcessRemainingCountN: + ProduceOutputBlock 16, RowCount + jmp Output16xNBlock + +ProcessNextColumnLoop48xN: + ProduceOutputBlock 48, RowCount + lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride + test r10b,r10b ; ZeroMode? + jnz SkipAccumulateOutput48xNBlock + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + +SkipAccumulateOutput48xNBlock: + EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 6, + add r8,16*4 ; advance matrix C by 16 columns +IF RowCount GT 3 + add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns +ENDIF + sub rbp,16 + jmp Output32xNBlock + + ENDM + +; +; Macro Description: +; +; This macro generates the common AVX512 code for the inner kernel to compute +; matrix multiplication. +; +; Arguments: +; +; Type - Supplies the kernel type string for function tags. +; +; Isa - Supplies the instruction set architecture string for function tags. +; + +GemmU8X8KernelAvx512Function MACRO Type, Isa + +;++ +; +; Routine Description: +; +; This routine is an inner kernel to compute matrix multiplication for a +; set of rows. 
+; +; Arguments: +; +; A (rcx) - Supplies the address of matrix A. The matrix data has been packed +; using MlasGemmU8X8CopyPackAAvx2. +; +; B (rdx) - Supplies the address of matrix B. The matrix data has been packed +; using MlasGemmU8X8CopyPackBAvx2. +; +; C (r8) - Supplies the address of matrix C. +; +; QuadCountK (r9) - Supplies the number of quad columns from matrix A and the +; number of quad rows from matrix B to iterate over. +; +; CountM - Supplies the maximum number of rows that can be processed for +; matrix A and matrix C. The actual number of rows handled for this +; invocation depends on the kernel implementation. +; +; CountN - Supplies the number of columns from matrix B and matrix C to iterate +; over. +; +; ldc - Supplies the first dimension of matrix C. +; +; RowSumVector - Supplies the sum of each row from matrix A multiplied by the +; zero point offset of matrix B. These values are accumulated into every +; row of matrix C. +; +; ColumnSumVector - Supplies the sum of each column from matrix B multiplied +; by the zero point offset of matrix A. These values are accumulated into +; every column of matrix C. +; +; DepthValue - Supplies the value CountK multiplied by the zero point offset +; of matrix A multiplied by the zero point offset of matrix B. This value is +; accumulated into every element of matrix C. +; +; ZeroMode - Supplies true if the output matrix must be zero initialized, +; else false if the output matrix is accumulated into. +; +; Return Value: +; +; Returns the number of rows handled. 
+; +;-- + + NESTED_ENTRY MlasGemm&Type&Kernel&Isa&, _TEXT + + rex_push_reg rbp + push_reg rbx + push_reg rsi + push_reg rdi + push_reg r12 + push_reg r13 + push_reg r14 + alloc_stack (GemmU8X8KernelFrame.SavedR14) + save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 + save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 + + END_PROLOGUE + + mov rdi,rcx + mov rbp,GemmU8X8KernelFrame.CountN[rsp] + mov rax,GemmU8X8KernelFrame.ldc[rsp] + shl rax,2 ; convert ldc to bytes + shl r9,2 ; convert to row length + movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] + mov r11,GemmU8X8KernelFrame.CountM[rsp] + mov r12,GemmU8X8KernelFrame.RowSumVector[rsp] + mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp] + mov esi,-1 + kmovw k1,esi ; update mask to write all columns +IFIDNI , +IFIDNI , + neg esi + vpbroadcastw zmm5,esi ; generate 512-bit word vector [0x0001] +ENDIF + mov r14,r9 + shl r14,4 ; compute matrix B packed stride +ELSE + lea r14,[r9*8] ; compute matrix B packed stride +ENDIF + +; +; Process CountM rows of the matrices. +; + + cmp r11,5 + ja ProcessCountM6 + je ProcessCountM5 + cmp r11,3 + ja ProcessCountM4 + je ProcessCountM3 + cmp r11,1 + je ProcessCountM1 + +ProcessCountM2: + ProcessCountM 2 + +ProcessCountM4: + ProcessCountM 4 + +ProcessCountM6: + mov r11d,6 ; return 6 rows handled + ProcessCountM 6 + +; +; Restore non-volatile registers and return. 
+; + +ExitKernel: + mov eax,r11d + vzeroupper + movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] + movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] + add rsp,(GemmU8X8KernelFrame.SavedR14) + + BEGIN_EPILOGUE + + pop r14 + pop r13 + pop r12 + pop rdi + pop rsi + pop rbx + pop rbp + ret + +ProcessCountM1: + ProcessCountM 1 + +ProcessCountM3: + ProcessCountM 3 + +ProcessCountM5: + ProcessCountM 5 + + NESTED_END MlasGemm&Type&Kernel&Isa&, _TEXT + + ENDM diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6ba13c728d..193172d6ef 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -194,6 +194,52 @@ void typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE; +typedef +void +(MLASCALL MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE)( + uint8_t* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumVector, + int16_t offb + ); + +typedef MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE* PMLAS_GEMM_U8S8_COPY_PACKA_ROUTINE; + +typedef +void +(MLASCALL MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE)( + int8_t* D, + const int8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumVector, + int16_t offa + ); + +typedef MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE* PMLAS_GEMM_U8S8_COPY_PACKB_ROUTINE; + +typedef +size_t +(MLASCALL MLAS_GEMM_U8S8_KERNEL)( + const uint8_t* A, + const int8_t* B, + int32_t* C, + size_t QuadCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + int32_t DepthValue, + bool ZeroMode + ); + +typedef MLAS_GEMM_U8S8_KERNEL* PMLAS_GEMM_U8S8_KERNEL; + typedef void (MLASCALL MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE)( @@ -228,7 +274,7 @@ size_t const int16_t* A, const uint8_t* B, int32_t* C, - size_t PairedCountK, + size_t PairCountK, size_t CountM, size_t CountN, size_t ldc, @@ -364,10 +410,18 @@ extern "C" { #endif #if defined(MLAS_TARGET_AMD64_IX86) + MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE 
MlasGemmU8S8CopyPackASse; + MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE MlasGemmU8S8CopyPackBSse; + MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelSse; MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackASse; MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBSse; MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelSse; #if defined(MLAS_TARGET_AMD64) + MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE MlasGemmU8S8CopyPackAAvx2; + MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE MlasGemmU8S8CopyPackBAvx2; + MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx2; + MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512BW; + MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512Vnni; MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackAAvx2; MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBAvx2; MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx2; @@ -490,6 +544,9 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) PMLAS_GEMM_FLOAT_KERNEL GemmFloatKernel; + PMLAS_GEMM_U8S8_COPY_PACKA_ROUTINE GemmU8S8CopyPackARoutine; + PMLAS_GEMM_U8S8_COPY_PACKB_ROUTINE GemmU8S8CopyPackBRoutine; + PMLAS_GEMM_U8S8_KERNEL GemmU8S8Kernel; PMLAS_GEMM_U8U8_COPY_PACKA_ROUTINE GemmU8U8CopyPackARoutine; PMLAS_GEMM_U8U8_COPY_PACKB_ROUTINE GemmU8U8CopyPackBRoutine; PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index e3634ebd37..2e74ebd3ba 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -85,6 +85,9 @@ Return Value: // this->GemmFloatKernel = MlasGemmFloatKernelSse; + this->GemmU8S8CopyPackARoutine = MlasGemmU8S8CopyPackASse; + this->GemmU8S8CopyPackBRoutine = MlasGemmU8S8CopyPackBSse; + this->GemmU8S8Kernel = MlasGemmU8S8KernelSse; this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackASse; this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBSse; this->GemmU8U8Kernel = MlasGemmU8U8KernelSse; @@ -157,6 +160,9 @@ Return Value: if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { + this->GemmU8S8CopyPackARoutine = 
MlasGemmU8S8CopyPackAAvx2; + this->GemmU8S8CopyPackBRoutine = MlasGemmU8S8CopyPackBAvx2; + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackAAvx2; this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBAvx2; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2; @@ -180,6 +186,7 @@ Return Value: if ((Cpuid7[1] & 0x40000000) != 0) { + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512BW; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512BW; // @@ -187,6 +194,8 @@ Return Value: // if ((Cpuid7[2] & 0x800) != 0) { + + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Vnni; } } diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index c5d07da984..9b6b998663 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -21,12 +21,577 @@ Abstract: // Define the default strides to step through slices of the input matrices. // -#define MLAS_GEMM_U8U8_STRIDEM 12 -#define MLAS_GEMM_U8U8_STRIDEN 128 +#define MLAS_GEMM_U8S8_STRIDEM 24 +#define MLAS_GEMM_U8S8_STRIDEN 256 +#define MLAS_GEMM_U8S8_STRIDEK 128 + +#define MLAS_GEMM_U8U8_STRIDEM 24 +#define MLAS_GEMM_U8U8_STRIDEN 256 #define MLAS_GEMM_U8U8_STRIDEK 128 #ifdef MLAS_TARGET_AMD64_IX86 +// +// U8S8 implementation using SSE2 intrinsics. +// + +void +MLASCALL +MlasGemmU8S8CopyPackASse( + uint8_t* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumVector, + int16_t offb + ) +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + +Arguments: + + D - Supplies the address of the destination packed buffer. + + A - Supplies the address of the source matrix. + + lda - Supplies the number of elements per row of the source matrix. + + CountM - Supplies the number of rows of the source matrix to copy. + + CountK - Supplies the number of columns of the source matrix to copy. 
+ + RowSumVector - Supplies the address of the buffer to receive the sums of + the elements from each of the rows. Each sum has also been multiplied + by the zero point offset. + + offb - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ +{ + const __m128i ZeroVector = _mm_setzero_si128(); + const __m128i OffsetBroadcast = _mm_set1_epi16(offb); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + __m128i RowSum = ZeroVector; + + // + // Copy the source bytes to the packed buffer. + // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 4 to maintain 32-bit + // alignment. All extra bytes are zero-padded. + // + // These values are also zero-extended and accumulated into an + // intermediate per-row accumulator. CountK cannot be greater than 128 + // to avoid overflowing these signed 16-bit accumulators. + // + + while (k >= 8) { + + __m128i Bytes = _mm_loadl_epi64((__m128i*)&a[0]); + _mm_storel_epi64((__m128i*)&D[0], Bytes); + + RowSum = _mm_add_epi16(RowSum, _mm_unpacklo_epi8(Bytes, ZeroVector)); + + D += 8; + a += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + __m128i Bytes = _mm_loadl_epi64((__m128i*)PaddedMatrixAData); + _mm_storel_epi64((__m128i*)&D[0], Bytes); + + RowSum = _mm_add_epi16(RowSum, _mm_unpacklo_epi8(Bytes, ZeroVector)); + + // + // Copy quads of 8-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. 
+ // + + for (size_t quads = (k + 3) / 4; quads > 0; quads--) { + *((int32_t*)D) = _mm_cvtsi128_si32(Bytes); + D += 4; + Bytes = _mm_shuffle_epi32(Bytes, _MM_SHUFFLE(0, 3, 2, 1)); + } + } + + // + // Reduce the sum for the single row of output and multiply by the + // zero point offset of the other source matrix. + // + + RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast); + RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(3, 2, 3, 2))); + RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(0, 1, 0, 1))); + + *RowSumVector++ = _mm_cvtsi128_si32(RowSum); + + A += lda; + CountM -= 1; + } +} + +void +MLASCALL +MlasGemmU8S8CopyPackBSse( + int8_t* D, + const int8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumVector, + int16_t offa + ) +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + +Arguments: + + D - Supplies the address of the destination packed buffer. + + B - Supplies the address of the source matrix. + + ldb - Supplies the number of elements per row of the source matrix. + + CountN - Supplies the number of columns of the source matrix to copy. + + CountK - Supplies the number of rows of the source matrix to copy. + + ColumnSumVector - Supplies the address of the buffer to receive the sums of + the elements from each of the columns. Each sum has also been multiplied + by the zero point offset. + + offa - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ +{ + const __m128i ZeroVector = _mm_setzero_si128(); + const __m128i OffsetBroadcast = _mm_set1_epi16(offa); + int8_t PaddedMatrixBData[16] = { 0 }; + + // + // Process 8 columns of matrix B in a loop. 
+ // + + while (CountN >= 8) { + + const int8_t* b = B; + size_t k = CountK; + __m128i ColumnSum0 = ZeroVector; + __m128i ColumnSum1 = ZeroVector; + + // + // Interleave 2 rows of matrix B and write to the packed buffer. + // + // These values are also sign-extended and accumulated into an + // intermediate per-column accumulator. CountK cannot be greater than + // 128 to avoid overflowing these signed 16-bit accumulators. + // + + while (k >= 2) { + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); + __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&b[ldb]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + __m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow); + __m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + // + // Process the remaining row of matrix B. + // + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + __m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow); + __m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh); + + D += 16; + } + + // + // The number of rows written to the packed buffer should be a multiple + // of 4. Zero pad the packed buffer if the block is not complete. 
+ // + + if (((CountK - 1) & 2) == 0) { + + _mm_storeu_si128((__m128i*)&D[0], ZeroVector); + + D += 16; + } + + ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast); + ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast); + + _mm_storeu_si128((__m128i*)&ColumnSumVector[0], ColumnSum0); + _mm_storeu_si128((__m128i*)&ColumnSumVector[4], ColumnSum1); + + ColumnSumVector += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const int8_t* b = B; + size_t k = CountK; + __m128i ColumnSum0 = ZeroVector; + __m128i ColumnSum1 = ZeroVector; + + // + // Interleave 2 rows of matrix B and write to the packed buffer. + // + // These values are also sign-extended and accumulated into an + // intermediate per-column accumulator. CountK cannot be greater than + // 128 to avoid overflowing these signed 16-bit accumulators. + // + + while (k >= 2) { + + // + // Copy the remaining columns to the zero padded stack buffer. + // + + const int8_t* bcopy = b; + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]); + __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[8]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + __m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow); + __m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + // + // Copy the remaining columns to the zero padded stack buffer. 
+ // + + const int8_t* bcopy = b; + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]); + __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector); + + _mm_storeu_si128((__m128i*)&D[0], BytesInterleaved); + + __m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow); + __m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8); + ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh); + + D += 16; + } + + // + // The number of rows written to the packed buffer should be a multiple + // of 4. Zero pad the packed buffer if the block is not complete. + // + + if (((CountK - 1) & 2) == 0) { + + _mm_storeu_si128((__m128i*)&D[0], ZeroVector); + + D += 16; + } + + ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast); + ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast); + + _mm_storeu_si128((__m128i*)&ColumnSumVector[0], ColumnSum0); + _mm_storeu_si128((__m128i*)&ColumnSumVector[4], ColumnSum1); + } +} + +size_t +MLASCALL +MlasGemmU8S8KernelSse( + const uint8_t* A, + const int8_t* B, + int32_t* C, + size_t PairCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + int32_t DepthValue, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8S8CopyPackASse. + + B - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8S8CopyPackBSse. + + C - Supplies the address of matrix C. 
+ + PairCountK - Supplies the number of quad columns from matrix A and the + number of quad rows from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to iterate + over. + + ldc - Supplies the first dimension of matrix C. + + RowSumVector - Supplies the sum of each row from matrix A multiplied by the + zero point offset of matrix B. These values are accumulated into every + row of matrix C. + + ColumnSumVector - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + DepthValue - Supplies the value CountK multiplied by the zero point offset + of matrix A multiplied by the zero point offset of matrix B. This value is + accumulated into every element of matrix C. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + const __m128i ZeroVector = _mm_setzero_si128(); + + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + while (CountN > 0) { + + // + // Initialize the accumulators with the sum of the global depth value + // constant, the column sums, and the row sums. 
+ // + + __m128i Accumulator0 = _mm_set1_epi32(DepthValue); + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_set1_epi32(RowSumVector[0])); + __m128i Accumulator1 = Accumulator0; + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&ColumnSumVector[0])); + Accumulator1 = _mm_add_epi32(Accumulator1, _mm_loadu_si128((__m128i*)&ColumnSumVector[4])); + ColumnSumVector += 8; + + // + // Broadcast each pair of 16-bit values from the matrix A and multiply + // with the sign-extended pair of 16-bit values from matrix B, and add + // the 32-bit intermediate into the accumulator registers. + // + + const uint8_t* a = A; + size_t k = PairCountK; + + while (k > 0) { + + __m128i AElements = _mm_unpacklo_epi8(_mm_cvtsi32_si128(*((int32_t*)a)), ZeroVector); + + __m128i BElements; + __m128i Intermediate0; + __m128i Intermediate1; + + BElements = _mm_loadu_si128((__m128i*)&B[0]); + Intermediate0 = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BElements), 8); + Intermediate1 = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BElements), 8); + + __m128i AElements0 = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(0, 0, 0, 0)); + + Intermediate0 = _mm_madd_epi16(Intermediate0, AElements0); + Intermediate1 = _mm_madd_epi16(Intermediate1, AElements0); + + Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0); + Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1); + + BElements = _mm_loadu_si128((__m128i*)&B[16]); + Intermediate0 = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BElements), 8); + Intermediate1 = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BElements), 8); + + __m128i AElements1 = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(1, 1, 1, 1)); + + Intermediate0 = _mm_madd_epi16(Intermediate0, AElements1); + Intermediate1 = _mm_madd_epi16(Intermediate1, AElements1); + + Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0); + Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1); + + a += 4; + B += 32; + k -= 1; + } + + // + // Output the accumulator block after 
optionally accumulating the values + // from matrix C. + // + + if (CountN >= 8) { + + if (!ZeroMode) { + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&C[0])); + Accumulator1 = _mm_add_epi32(Accumulator1, _mm_loadu_si128((__m128i*)&C[4])); + } + + _mm_storeu_si128((__m128i*)&C[0], Accumulator0); + _mm_storeu_si128((__m128i*)&C[4], Accumulator1); + + C += 8; + CountN -= 8; + + } else { + + // + // Output the remaining partial output block. + // + + if ((CountN & 4) != 0) { + + if (!ZeroMode) { + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&C[0])); + } + + _mm_storeu_si128((__m128i*)&C[0], Accumulator0); + C += 4; + + Accumulator0 = Accumulator1; + } + + if ((CountN & 2) != 0) { + + if (!ZeroMode) { + Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadl_epi64((__m128i*)&C[0])); + } + + _mm_storel_epi64((__m128i*)&C[0], Accumulator0); + C += 2; + + Accumulator0 = _mm_shuffle_epi32(Accumulator0, _MM_SHUFFLE(1, 0, 3, 2)); + } + + if ((CountN & 1) != 0) { + + int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulator0); + + if (!ZeroMode) { + AccumulatorValue += C[0]; + } + + C[0] = AccumulatorValue; + } + + CountN = 0; + } + } + + return 1; +} + +// +// U8U8 implementation using SSE2 intrinsics. +// + void MLASCALL MlasGemmU8U8CopyPackASse( @@ -86,13 +651,15 @@ Return Value: // // Zero extend the source bytes to 16-bits and write to the packed - // buffer. The packed buffer has the same data ordering as the source - // bytes, but the stride is CountK aligned up to a multiple of 8 - // values. + // buffer. + // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 2 to maintain 32-bit + // alignment. All extra bytes are zero-padded. // // These 16-bit values are also accumulated into an intermediate per-row - // accumulator. CountK cannot be greater than 256 to avoid overflowing - // these 16-bit accumulators. + // accumulator. 
CountK cannot be greater than 128 to avoid overflowing + // these signed 16-bit accumulators. // while (k >= 8) { @@ -130,8 +697,8 @@ Return Value: RowSum = _mm_add_epi16(RowSum, Words); // - // Copy the 16-bit pairs from the vector to the destination packed - // buffer. Rotate the vector at each iteration. + // Copy pairs of 16-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. // for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { @@ -142,7 +709,8 @@ Return Value: } // - // Reduce the sum for the single row of output. + // Reduce the sum for the single row of output and multiply by the + // zero point offset of the other source matrix. // RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast); @@ -176,13 +744,13 @@ Routine Description: Arguments: - D (rcx) - Supplies the address of the destination packed buffer. + D - Supplies the address of the destination packed buffer. - B (rdx) - Supplies the address of the source matrix. + B - Supplies the address of the source matrix. - ldb (r8) - Supplies the number of elements per row of the source matrix. + ldb - Supplies the number of elements per row of the source matrix. - CountN (r9) - Supplies the number of columns of the source matrix to copy. + CountN - Supplies the number of columns of the source matrix to copy. CountK - Supplies the number of rows of the source matrix to copy. @@ -214,6 +782,14 @@ Return Value: __m128i ColumnSum0 = ZeroVector; __m128i ColumnSum1 = ZeroVector; + // + // Interleave 2 rows of matrix B and write to the packed buffer. + // + // These values are also zero-extended and accumulated into an + // intermediate per-column accumulator. CountK cannot be greater than + // 128 to avoid overflowing these signed 16-bit accumulators. + // + while (k >= 2) { __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); @@ -232,6 +808,10 @@ Return Value: if (k > 0) { + // + // Process the remaining row of matrix B. 
+ // + __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]); __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector); @@ -240,7 +820,6 @@ Return Value: ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector)); ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector)); - b += ldb * 2; D += 16; } @@ -323,6 +902,11 @@ Return Value: ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector)); } + // + // Reduce the sum for the packed columns and multiply by the zero point + // offset of the other source matrix. + // + ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast); ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast); @@ -337,7 +921,7 @@ MlasGemmU8U8KernelSse( const int16_t* A, const uint8_t* B, int32_t* C, - size_t PairedCountK, + size_t PairCountK, size_t CountM, size_t CountN, size_t ldc, @@ -363,8 +947,8 @@ Arguments: C - Supplies the address of matrix C. - PairedCountK - Supplies the number of paired columns from matrix A and - the number of paired rows from matrix B to iterate over. + PairCountK - Supplies the number of paired columns from matrix A and the + number of paired rows from matrix B to iterate over. CountM - Supplies the maximum number of rows that can be processed for matrix A and matrix C. 
The actual number of rows handled for this @@ -422,18 +1006,18 @@ Return Value: // const int16_t* a = A; - size_t k = PairedCountK; + size_t k = PairCountK; while (k > 0) { - __m128i AElements0 = _mm_set1_epi32(*((int32_t*)a)); + __m128i AElements = _mm_set1_epi32(*((int32_t*)a)); __m128i BElements0 = _mm_loadu_si128((__m128i*)&B[0]); __m128i Intermediate0 = _mm_unpacklo_epi8(BElements0, ZeroVector); __m128i Intermediate1 = _mm_unpackhi_epi8(BElements0, ZeroVector); - Intermediate0 = _mm_madd_epi16(Intermediate0, AElements0); - Intermediate1 = _mm_madd_epi16(Intermediate1, AElements0); + Intermediate0 = _mm_madd_epi16(Intermediate0, AElements); + Intermediate1 = _mm_madd_epi16(Intermediate1, AElements); Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0); Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1); @@ -502,7 +1086,7 @@ Return Value: C[0] = AccumulatorValue; } - break; + CountN = 0; } } @@ -511,7 +1095,94 @@ Return Value: void MLASCALL -MlasQgemm( +MlasGemm( + size_t M, + size_t N, + size_t K, + const uint8_t* A, + size_t lda, + uint8_t offa, + const int8_t* B, + size_t ldb, + int8_t offb, + int32_t* C, + size_t ldc, + MLAS_THREADPOOL* ThreadPool + ) +{ + MLAS_DECLSPEC_ALIGN(uint8_t PanelA[MLAS_GEMM_U8S8_STRIDEM * MLAS_GEMM_U8S8_STRIDEK], 64); + MLAS_DECLSPEC_ALIGN(int8_t PanelB[MLAS_GEMM_U8S8_STRIDEN * MLAS_GEMM_U8S8_STRIDEK], 64); + + MLAS_DECLSPEC_ALIGN(int32_t RowSumVector[MLAS_GEMM_U8S8_STRIDEM], 16); + MLAS_DECLSPEC_ALIGN(int32_t ColumnSumVector[MLAS_GEMM_U8S8_STRIDEN], 16); + + size_t StrideM = MLAS_GEMM_U8S8_STRIDEM; + size_t StrideN = MLAS_GEMM_U8S8_STRIDEN; + size_t StrideK = MLAS_GEMM_U8S8_STRIDEK; + + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + size_t CountK; + + for (size_t k = 0; k < K; k += CountK) { + + CountK = StrideK; + + if (CountK > (K - k)) { + CountK = K - k; + } + + size_t CountN; + + for (size_t n = 0; n < N; n += CountN) { + + CountN = StrideN; + + if (CountN > (N - n)) { + CountN = N - n; + } + + 
MlasPlatform.GemmU8S8CopyPackBRoutine(PanelB, B + n + k * ldb, ldb, CountN, CountK, ColumnSumVector, -int16_t(offa)); + + size_t CountM; + + for (size_t m = 0; m < M; m += CountM) { + + CountM = StrideM; + + if (CountM > (M - m)) { + CountM = M - m; + } + + MlasPlatform.GemmU8S8CopyPackARoutine(PanelA, A + k + m * lda, lda, CountM, CountK, RowSumVector, -int16_t(offb)); + + uint8_t* pa = PanelA; + int32_t* c = C + n + m * ldc; + + int32_t* RowSums = RowSumVector; + + size_t RowsRemaining = CountM; + size_t RowsHandled; + + size_t QuadCountK = (CountK + 3) / 4; + + while (RowsRemaining > 0) { + + RowsHandled = MlasPlatform.GemmU8S8Kernel(pa, PanelB, c, QuadCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0); + + RowsRemaining -= RowsHandled; + c += ldc * RowsHandled; + pa += 4 * QuadCountK * RowsHandled; + RowSums += RowsHandled; + } + } + } + } +} + +void +MLASCALL +MlasGemm( size_t M, size_t N, size_t K, @@ -580,15 +1251,15 @@ MlasQgemm( size_t RowsRemaining = CountM; size_t RowsHandled; - size_t PairedCountK = (CountK + 1) / 2; + size_t PairCountK = (CountK + 1) / 2; while (RowsRemaining > 0) { - RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairedCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0); + RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0); RowsRemaining -= RowsHandled; c += ldc * RowsHandled; - pa += 2 * PairedCountK * RowsHandled; + pa += 2 * PairCountK * RowsHandled; RowSums += RowsHandled; } } diff --git a/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h b/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h index bd3112bd9c..f02fc3bca4 100644 --- a/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h +++ b/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h @@ -140,7 +140,7 @@ Macro Description: This macro builds a VNNI 
instruction of the form: - instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale]{1to16} + instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale+ByteOffset]{1to16} Arguments: @@ -152,13 +152,16 @@ Arguments: BaseReg - Specifies the base register of the broadcast operand. + ByteOffset - Specifies the DWORD aligned byte offset for the broadcast + operand. + IndexReg - Specifies the optional index register of the broadcast operand. Scale - Specifies the scaling factor of the optional index register. --*/ - .macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale + .macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale .set Payload0, 0x02 # "0F 38" prefix .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) @@ -183,6 +186,9 @@ Arguments: .else .set ModRMByte, ModRMByte + (.LGprIndex_\BaseReg\() & 7) .endif +.if \ByteOffset\() != 0 + .set ModRMByte, ModRMByte + 0x40 # indicate disp8 byte offset +.endif .ifnes "\IndexReg\()", "" .set SibByte, 0 @@ -205,34 +211,36 @@ Arguments: .set SibByte, SibByte + (.LGprIndex_\BaseReg\() & 7) .endif -.ifnes "\IndexReg\()", "" - .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte, SibByte -.else .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte +.ifnes "\IndexReg\()", "" + .byte SibByte +.endif +.if \ByteOffset\() != 0 + .byte (\ByteOffset\() >> 2) .endif .endm - .macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + .macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() .endm - .macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + .macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), 
\BaseReg\(), \IndexReg\(), \Scale\() + VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() .endm - .macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + .macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() .endm - .macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale + .macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\() + VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S new file mode 100644 index 0000000000..b11363c118 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S @@ -0,0 +1,955 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8S8KernelAvx2.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX2 instructions. + +--*/ + +#include "asmmacro.h" +#include "QgemmU8X8KernelAvx2Common.h" + + .intel_syntax noprefix + +// +// Stack frame layout for the U8S8 CopyPackA routine. 
+// + + .equ .LGemmU8S8CopyPackAFrame_PaddedMatrixAData, -72 + .equ .LGemmU8S8CopyPackAFrame_mask, -8 + .equ .LGemmU8S8CopyPackAFrame_SavedR13, 0 + .equ .LGemmU8S8CopyPackAFrame_SavedR12, 8 + .equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16 + .equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24 + .equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32 + .equ .LGemmU8S8CopyPackAFrame_offb, 40 + +// +// Stack frame layout for the U8S8 CopyPackB routine. +// + + .equ .LGemmU8S8CopyPackBFrame_PaddedMatrixBData, -72 + .equ .LGemmU8S8CopyPackBFrame_Padding, -8 + .equ .LGemmU8S8CopyPackBFrame_SavedRbx, 0 + .equ .LGemmU8S8CopyPackBFrame_SavedRbp, 8 + .equ .LGemmU8S8CopyPackBFrame_ReturnAddress, 16 + .equ .LGemmU8S8CopyPackBFrame_offa, 24 + + .text + +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + +Arguments: + + D (rdi) - Supplies the address of the destination packed buffer. + + A (rsi) - Supplies the address of the source matrix. + + lda (rdx) - Supplies the number of elements per row of the source matrix. + + CountM (rcx) - Supplies the number of rows of the source matrix to copy. + + CountK (r8) - Supplies the number of columns of the source matrix to copy. + + RowSumVector (r9) - Supplies the address of the buffer to receive the sums + of the elements from each of the rows. Each sum has also been multiplied + by the zero point offset. + + offb - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. 
+ +--*/ + + .globl C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2) +C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2): + + push rbp + push rbx + push r12 + push r13 + + mov r10,rdx + mov r11,rcx + lea r12,[r8+3] + and r12,NOT 3 # align CountK up to quad count + vpbroadcastw xmm8,WORD PTR .LGemmU8S8CopyPackAFrame_offb[rsp] + vpcmpeqw ymm9,ymm9,ymm9 # generate word vector [0xFFFF] + vpsrlw ymm9,ymm9,15 # generate word vector [0x0001] + vpsllw ymm0,ymm9,8 # generate word vector [0x0100] + vpor ymm9,ymm9,ymm0 # generate word vector [0x0101] + +// +// Compute the conditional load/store mask for an unaligned CountK. +// + + mov eax,r8d + and eax,15 # isolate unaligned count + add eax,3 + shr eax,2 # align unaligned count to quad count + mov DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp],eax + vpbroadcastd xmm10,DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp] + vpcmpgtd xmm10,xmm10,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] + +// +// Zero initialize the padded stack buffers. +// + + vpxor xmm0,xmm0,xmm0 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0 + +// +// Process 4 rows of matrix A in a loop. 
+// + + sub r11,4 + jb .LCopyPackA.ProcessRemainingRows + +.LCopyPackA.ProcessNextRowM4: + vpxor xmm0,xmm0,xmm0 # clear row accumulators + vpxor xmm1,xmm1,xmm1 + vpxor xmm2,xmm2,xmm2 + vpxor xmm3,xmm3,xmm3 + mov rdx,rsi + mov rcx,rdi + lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows + lea rdi,[rdi+r12*4] # advance next matrix D by 4 rows + mov rbx,r8 # reload columns remaining + sub rbx,32 + jb .LCopyPackA.ProcessRemainingColumnsM4 + +.LCopyPackA.ProcessNextColumnLoopM4: + lea rax,[rdx+r10*2] # compute matrix A plus 2 rows + vmovdqu ymm4,YMMWORD PTR [rdx] + vmovdqu ymm5,YMMWORD PTR [rdx+r10] + vmovdqu ymm6,YMMWORD PTR [rax] + vmovdqu ymm7,YMMWORD PTR [rax+r10] + lea rax,[rcx+r12*2] # compute matrix D plus 2 rows + vmovdqu YMMWORD PTR [rcx],ymm4 + vmovdqu YMMWORD PTR [rcx+r12],ymm5 + vmovdqu YMMWORD PTR [rax],ymm6 + vmovdqu YMMWORD PTR [rax+r12],ymm7 + vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + vpmaddubsw ymm5,ymm5,ymm9 + vpaddw ymm1,ymm1,ymm5 + vpmaddubsw ymm6,ymm6,ymm9 + vpaddw ymm2,ymm2,ymm6 + vpmaddubsw ymm7,ymm7,ymm9 + vpaddw ymm3,ymm3,ymm7 + add rdx,32 # advance matrix A by 32 bytes + add rcx,32 # advance matrix D by 32 bytes + sub rbx,32 # subtract columns remaining + jae .LCopyPackA.ProcessNextColumnLoopM4 + +.LCopyPackA.ProcessRemainingColumnsM4: + add rbx,32 # correct for over-subtract above + jz .LCopyPackA.ReduceRowSumVectorM4 + test bl,16 # (CountK & 16) != 0? 
+ jz .LCopyPackA.CopyRemainingCountKLessThan16M4 + lea rax,[rdx+r10*2] # compute matrix A plus 2 rows + vmovdqu xmm4,XMMWORD PTR [rdx] + vmovdqu xmm5,XMMWORD PTR [rdx+r10] + vmovdqu xmm6,XMMWORD PTR [rax] + vmovdqu xmm7,XMMWORD PTR [rax+r10] + lea rax,[rcx+r12*2] # compute matrix D plus 2 rows + vmovdqu XMMWORD PTR [rcx],xmm4 + vmovdqu XMMWORD PTR [rcx+r12],xmm5 + vmovdqu XMMWORD PTR [rax],xmm6 + vmovdqu XMMWORD PTR [rax+r12],xmm7 + vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + vpmaddubsw xmm5,xmm5,xmm9 + vpaddw ymm1,ymm1,ymm5 + vpmaddubsw xmm6,xmm6,xmm9 + vpaddw ymm2,ymm2,ymm6 + vpmaddubsw xmm7,xmm7,xmm9 + vpaddw ymm3,ymm3,ymm7 + add rdx,16 # advance matrix A by 16 bytes + add rcx,16 # advance matrix D by 16 bytes + test bl,15 # test for unaligned columns + jz .LCopyPackA.ReduceRowSumVectorM4 + +// +// Copy the unaligned CountK columns to a zero padded stack buffer. +// + +.LCopyPackA.CopyRemainingCountKLessThan16M4: + lea rbp,.LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp] + test bl,8 # (CountK & 8) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan8M4 + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r10] + mov QWORD PTR [rbp+16],rax + mov rax,QWORD PTR [r13] + mov QWORD PTR [rbp+32],rax + mov rax,QWORD PTR [r13+r10] + mov QWORD PTR [rbp+48],rax + add rdx,8 + add rbp,8 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan8M4: + test bl,4 # (CountK & 4) != 0? 
+ jz .LCopyPackA.CopyRemainingCountKLessThan4M4 + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r10] + mov DWORD PTR [rbp+16],eax + mov eax,DWORD PTR [r13] + mov DWORD PTR [rbp+32],eax + mov eax,DWORD PTR [r13+r10] + mov DWORD PTR [rbp+48],eax + add rdx,4 + add rbp,4 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan4M4: + test bl,2 # (CountK & 2) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan2M4 + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r10] + mov WORD PTR [rbp+16],ax + movzx eax,WORD PTR [r13] + mov WORD PTR [rbp+32],ax + movzx eax,WORD PTR [r13+r10] + mov WORD PTR [rbp+48],ax + add rdx,2 + add rbp,2 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan2M4: + test bl,1 # (CountK & 1) != 0? + jz .LCopyPackA.ProcessPaddedMatrixADataM4 + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r10] + mov BYTE PTR [rbp+16],al + movzx eax,BYTE PTR [r13] + mov BYTE PTR [rbp+32],al + movzx eax,BYTE PTR [r13+r10] + mov BYTE PTR [rbp+48],al + +// +// Process the remaining CountK columns using the zero padded stack buffer. 
+// + +.LCopyPackA.ProcessPaddedMatrixADataM4: + vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp] + vmovdqu xmm5,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+16] + vmovdqu xmm6,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32] + vmovdqu xmm7,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+48] + lea rax,[rcx+r12*2] # compute matrix D plus 2 rows + vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 + vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5 + vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 + vpmaskmovd XMMWORD PTR [rax+r12],xmm10,xmm7 + vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + vpmaddubsw xmm5,xmm5,xmm9 + vpaddw ymm1,ymm1,ymm5 + vpmaddubsw xmm6,xmm6,xmm9 + vpaddw ymm2,ymm2,ymm6 + vpmaddubsw xmm7,xmm7,xmm9 + vpaddw ymm3,ymm3,ymm7 + +// +// Reduce the sums for the four rows of output. +// + +.LCopyPackA.ReduceRowSumVectorM4: + vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 + vphaddw ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 + vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 + vextracti128 xmm1,ymm0,1 # extract high pairs + vpaddw xmm0,xmm0,xmm1 # reduce low/high pairs + vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce 32-bit sum + vmovdqu XMMWORD PTR [r9],xmm0 + add r9,4*4 # advance row sum vector by 4 DWORDs + sub r11,4 # subtract rows remaining + jae .LCopyPackA.ProcessNextRowM4 + +.LCopyPackA.ProcessRemainingRows: + add r11,4 # correct for over-subtract above + jz .LCopyPackA.ExitRoutine + +// +// Process a single row of matrix A in a loop. 
+// + +.LCopyPackA.ProcessNextRowM1: + vpxor xmm0,xmm0,xmm0 # clear row accumulator + mov rdx,rsi + mov rcx,rdi + add rsi,r10 + add rdi,r12 + mov rbx,r8 # reload columns remaining + sub rbx,32 + jb .LCopyPackA.ProcessRemainingColumnsM1 + +.LCopyPackA.ProcessNextColumnLoopM1: + vmovdqu ymm4,YMMWORD PTR [rdx] + vmovdqu YMMWORD PTR [rcx],ymm4 + vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + add rdx,32 # advance matrix A by 32 bytes + add rcx,32 # advance matrix D by 32 bytes + sub rbx,32 # subtract columns remaining + jae .LCopyPackA.ProcessNextColumnLoopM1 + +.LCopyPackA.ProcessRemainingColumnsM1: + add rbx,32 # correct for over-subtract above + jz .LCopyPackA.ReduceRowSumVectorM1 + test bl,16 # (CountK & 16) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan16M1 + vmovdqu xmm4,XMMWORD PTR [rdx] + vmovdqu XMMWORD PTR [rcx],xmm4 + vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + add rdx,16 # advance matrix A by 16 bytes + add rcx,16 # advance matrix D by 16 bytes + test bl,15 # test for unaligned columns + jz .LCopyPackA.ReduceRowSumVectorM1 + +// +// Copy the unaligned CountK columns to a zero padded stack buffer. +// + +.LCopyPackA.CopyRemainingCountKLessThan16M1: + lea rbp,.LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp] + test bl,8 # (CountK & 8) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan8M1 + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 + add rbp,8 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan8M1: + test bl,4 # (CountK & 4) != 0? + jz .LCopyPackA.CopyRemainingCountKLessThan4M1 + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 + add rbp,4 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan4M1: + test bl,2 # (CountK & 2) != 0? 
+ jz .LCopyPackA.CopyRemainingCountKLessThan2M1 + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 + add rbp,2 # advance padded buffer destination + +.LCopyPackA.CopyRemainingCountKLessThan2M1: + test bl,1 # (CountK & 1) != 0? + jz .LCopyPackA.ProcessPaddedMatrixADataM1 + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +// +// Process the remaining CountK columns using the zero padded stack buffer. +// + +.LCopyPackA.ProcessPaddedMatrixADataM1: + vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp] + vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 + vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns + +// +// Reduce the sum for the single row of output. +// + +.LCopyPackA.ReduceRowSumVectorM1: + vextracti128 xmm1,ymm0,1 # extract high pairs + vpaddw xmm0,xmm0,xmm1 # reduction + vphaddw xmm0,xmm0,xmm0 + vphaddw xmm0,xmm0,xmm0 + vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce + vmovd DWORD PTR [r9],xmm0 + add r9,4 # advance row sum vector by 1 DWORD + dec r11 # decrement rows remaining + jnz .LCopyPackA.ProcessNextRowM1 + +// +// Restore non-volatile registers and return. +// + +.LCopyPackA.ExitRoutine: + vzeroupper + + pop r13 + pop r12 + pop rbx + pop rbp + ret + +/*++ + +Routine Description: + + This routine copies elements from the source matrix to the destination + packed buffer. + +Arguments: + + D (rdi) - Supplies the address of the destination packed buffer. + + B (rsi) - Supplies the address of the source matrix. + + ldb (rdx) - Supplies the number of elements per row of the source matrix. + + CountN (rcx) - Supplies the number of columns of the source matrix to copy. + + CountK (r8) - Supplies the number of rows of the source matrix to copy. + + ColumnSumVector (r9) - Supplies the address of the buffer to receive the sums + of the elements from each of the columns. Each sum has also been + multiplied by the zero point offset. 
+ + offa - Supplies the zero point offset for the other source matrix of the + matrix multiplication. + +Return Value: + + None. + +--*/ + + .globl C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2) +C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2): + + push rbp + push rbx + + mov r10,rdx + vpbroadcastw ymm7,WORD PTR .LGemmU8S8CopyPackBFrame_offa[rsp] + vpcmpeqw ymm8,ymm8,ymm8 # generate word vector [0xFFFF] + vpsrlw ymm8,ymm8,15 # generate word vector [0x0001] + vpsllw ymm0,ymm8,8 # generate word vector [0x0100] + vpor ymm8,ymm8,ymm0 # generate word vector [0x0101] + +// +// Process 16 columns of matrix B in a loop. +// + + sub rcx,16 + jb .LCopyPackB.ProcessRemainingColumns + +.LCopyPackB.ProcessNextColumnN16: + vpxor xmm0,xmm0,xmm0 # clear column accumulators + vpxor xmm1,xmm1,xmm1 + mov rdx,rsi + add rsi,16 # advance next matrix B by 16 columns + mov rbx,r8 # reload rows remaining + sub rbx,4 + jb .LCopyPackB.ProcessRemainingRowsN16 + +.LCopyPackB.ProcessNextRowLoopN16: + lea rax,[rdx+r10*2] # compute matrix B plus 2 rows + vmovdqu xmm2,XMMWORD PTR [rdx] # load 4 rows + vmovdqu xmm3,XMMWORD PTR [rdx+r10] + vmovdqu xmm4,XMMWORD PTR [rax] + vmovdqu xmm5,XMMWORD PTR [rax+r10] + lea rdx,[rdx+r10*4] # advance matrix B by 4 rows + +.LCopyPackB.InterleaveRowDataN16: + vpunpcklbw xmm6,xmm2,xmm3 # interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vpunpcklbw xmm2,xmm4,xmm5 + vpunpckhbw xmm5,xmm4,xmm5 + vpunpcklwd xmm4,xmm6,xmm2 + vpunpckhwd xmm6,xmm6,xmm2 + vpunpcklwd xmm2,xmm3,xmm5 + vpunpckhwd xmm3,xmm3,xmm5 + vinsertf128 ymm4,ymm4,xmm6,1 + vinsertf128 ymm2,ymm2,xmm3,1 + vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows + vmovdqu YMMWORD PTR [rdi+32],ymm2 + vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + vpmaddubsw ymm2,ymm8,ymm2 + vpaddw ymm1,ymm1,ymm2 + add rdi,64 # advance matrix D by 64 bytes + sub rbx,4 # subtract rows remaining + jae .LCopyPackB.ProcessNextRowLoopN16 + +// +// Process the less than 4 
remaining rows where the row has 16 columns. +// + +.LCopyPackB.ProcessRemainingRowsN16: + add rbx,4 # correct for over-subtract above + jz .LCopyPackB.ReduceColumnSumVectorN16 + vmovdqu xmm2,XMMWORD PTR [rdx] + vpxor xmm3,xmm3,xmm3 + vpxor xmm4,xmm4,xmm4 + vpxor xmm5,xmm5,xmm5 + xor ebx,ebx # no more rows remaining + test r8b,2 # (CountK & 2) != 0? + jz .LCopyPackB.InterleaveRowDataN16 + vmovdqu xmm3,XMMWORD PTR [rdx+r10] + test r8b,1 # (CountK & 1) != 0? + jz .LCopyPackB.InterleaveRowDataN16 + vmovdqu xmm4,XMMWORD PTR [rdx+r10*2] + jmp .LCopyPackB.InterleaveRowDataN16 + +.LCopyPackB.ReduceColumnSumVectorN16: + vpmaddwd ymm0,ymm0,ymm7 # multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm7 # multiply by offset and reduce + vmovdqu YMMWORD PTR [r9],ymm0 + vmovdqu YMMWORD PTR [r9+32],ymm1 + add r9,16*4 # advance column sum vector by 16 DWORDs + sub rcx,16 # subtract columns remaining + jae .LCopyPackB.ProcessNextColumnN16 + +.LCopyPackB.ProcessRemainingColumns: + add rcx,16 # correct for over-subtract above + jnz .LCopyPackB.ProcessColumnNUnaligned + +// +// Restore non-volatile registers and return. +// + +.LCopyPackB.ExitRoutine: + vzeroupper + + pop rbx + pop rbp + ret + +// +// Process the remaining columns of matrix B. +// + +.LCopyPackB.ProcessColumnNUnaligned: + vpxor xmm0,xmm0,xmm0 # clear column accumulators + vpxor xmm1,xmm1,xmm1 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp],ymm0 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+32],ymm0 + sub r8,4 + jb .LCopyPackB.ProcessRemainingRowsNUnaligned + +.LCopyPackB.ProcessNextRowLoopNUnaligned: + mov rdx,rsi + lea rbp,.LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp] + test cl,8 # (CountN & 8) != 0? 
+ jz .LCopyPackB.CopyRemainingCountNLessThan8K4 + lea r11,[rdx+r10*2] # compute matrix B plus 2 rows + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + mov rax,QWORD PTR [rdx+r10] + mov QWORD PTR [rbp+16],rax + mov rax,QWORD PTR [r11] + mov QWORD PTR [rbp+32],rax + mov rax,QWORD PTR [r11+r10] + mov QWORD PTR [rbp+48],rax + add rdx,8 # advance matrix B + add rbp,8 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan8K4: + test cl,4 # (CountN & 4) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan4K4 + lea r11,[rdx+r10*2] # compute matrix B plus 2 rows + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + mov eax,DWORD PTR [rdx+r10] + mov DWORD PTR [rbp+16],eax + mov eax,DWORD PTR [r11] + mov DWORD PTR [rbp+32],eax + mov eax,DWORD PTR [r11+r10] + mov DWORD PTR [rbp+48],eax + add rdx,4 # advance matrix B + add rbp,4 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan4K4: + test cl,2 # (CountN & 2) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan2K4 + lea r11,[rdx+r10*2] # compute matrix B plus 2 rows + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + movzx eax,WORD PTR [rdx+r10] + mov WORD PTR [rbp+16],ax + movzx eax,WORD PTR [r11] + mov WORD PTR [rbp+32],ax + movzx eax,WORD PTR [r11+r10] + mov WORD PTR [rbp+48],ax + add rdx,2 # advance matrix B + add rbp,2 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan2K4: + test cl,1 # (CountN & 1) != 0? 
+ jz .LCopyPackB.ProcessPaddedMatrixBData + lea r11,[rdx+r10*2] # compute matrix B plus 2 rows + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + movzx eax,BYTE PTR [rdx+r10] + mov BYTE PTR [rbp+16],al + movzx eax,BYTE PTR [r11] + mov BYTE PTR [rbp+32],al + movzx eax,BYTE PTR [r11+r10] + mov BYTE PTR [rbp+48],al + +.LCopyPackB.ProcessPaddedMatrixBData: + vmovdqu xmm2,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp] + vmovdqu xmm3,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+16] + vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+32] + vmovdqu xmm5,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+48] + vpunpcklbw xmm6,xmm2,xmm3 # interleave row data + vpunpckhbw xmm3,xmm2,xmm3 + vpunpcklbw xmm2,xmm4,xmm5 + vpunpckhbw xmm5,xmm4,xmm5 + vpunpcklwd xmm4,xmm6,xmm2 + vpunpckhwd xmm6,xmm6,xmm2 + vpunpcklwd xmm2,xmm3,xmm5 + vpunpckhwd xmm3,xmm3,xmm5 + vinsertf128 ymm4,ymm4,xmm6,1 + vinsertf128 ymm2,ymm2,xmm3,1 + vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows + vmovdqu YMMWORD PTR [rdi+32],ymm2 + vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row + vpaddw ymm0,ymm0,ymm4 # add words to row accumulators + vpmaddubsw ymm2,ymm8,ymm2 + vpaddw ymm1,ymm1,ymm2 + lea rsi,[rsi+r10*4] # advance next matrix B by 4 rows + add rdi,64 # advance matrix D by 64 bytes + sub r8,4 # subtract rows remaining + jae .LCopyPackB.ProcessNextRowLoopNUnaligned + +.LCopyPackB.ProcessRemainingRowsNUnaligned: + add r8,4 + jz .LCopyPackB.ReduceColumnSumVectorNUnaligned + +// +// Process the less than 4 remaining rows where the row has less than 16 columns. +// + + lea rbp,.LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp] + vpxor xmm6,xmm6,xmm6 + vmovdqu YMMWORD PTR [rbp],ymm6 + vmovdqu YMMWORD PTR [rbp+32],ymm6 + +.LCopyPackB.CopyUnalignedRowLoop: + lea r11,[rbp+16] # advance next padded buffer by 16 bytes + mov rdx,rsi + test cl,8 # (CountN & 8) != 0? 
+ jz .LCopyPackB.CopyRemainingCountNLessThan8KSmall + mov rax,QWORD PTR [rdx] + mov QWORD PTR [rbp],rax + add rdx,8 # advance matrix B + add rbp,8 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan8KSmall: + test cl,4 # (CountN & 4) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan4KSmall + mov eax,DWORD PTR [rdx] + mov DWORD PTR [rbp],eax + add rdx,4 # advance matrix B + add rbp,4 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan4KSmall: + test cl,2 # (CountN & 2) != 0? + jz .LCopyPackB.CopyRemainingCountNLessThan2KSmall + movzx eax,WORD PTR [rdx] + mov WORD PTR [rbp],ax + add rdx,2 # advance matrix B + add rbp,2 # advance padded buffer destination + +.LCopyPackB.CopyRemainingCountNLessThan2KSmall: + test cl,1 # (CountN & 1) != 0? + jz .LCopyPackB.DoneCopyRemainingCountNKSmall + movzx eax,BYTE PTR [rdx] + mov BYTE PTR [rbp],al + +.LCopyPackB.DoneCopyRemainingCountNKSmall: + dec r8 + jz .LCopyPackB.ProcessPaddedMatrixBData + add rsi,r10 # advance next matrix B by 1 row + mov rbp,r11 + jmp .LCopyPackB.CopyUnalignedRowLoop + +.LCopyPackB.ReduceColumnSumVectorNUnaligned: + vpmaddwd ymm0,ymm0,ymm7 # multiply by offset and reduce + vpmaddwd ymm1,ymm1,ymm7 # multiply by offset and reduce + vmovdqu YMMWORD PTR [r9],ymm0 + vmovdqu YMMWORD PTR [r9+32],ymm1 + jmp .LCopyPackB.ExitRoutine + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate a single row of the + output block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + Vec1Reg - Supplies the high block accumulator register (when ColumnCount + is 16). + + Vec2Reg - Supplies the low block accumulator register. + +Implicit Arguments: + + ymm0 - Supplies the first vector loaded from matrix B. + + ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount + is 16). + + ymm2 - Supplies the broadcast value loaded from matrix A.
+ + ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001. + +--*/ + + .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg + + vpmaddubsw ymm3,ymm2,ymm0 + vpmaddwd ymm3,ymm3,ymm12 +.if \ColumnCount\() == 16 + vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 + vpmaddubsw ymm2,ymm2,ymm1 + vpmaddwd ymm2,ymm2,ymm12 + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 +.else + vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm11 - Supplies the block accumulators. + + ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
+ +--*/ + + .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset + +.if \RowCount\() == 1 + vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()] + vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()] + vpmaddwd ymm3,ymm3,ymm12 +.if \ColumnCount\() == 16 + vpaddd ymm4,ymm4,ymm3 + vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32] + vpmaddwd ymm2,ymm2,ymm12 + vpaddd ymm5,ymm5,ymm2 +.else + vpaddd ymm5,ymm5,ymm3 +.endif +.else + vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] + EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11" +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + ymm4-ymm11 - Supplies the block accumulators. 
+ +--*/ + + .macro ComputeBlockLoop ColumnCount, RowCount + + mov rbp,rcx # reload row length remaining + +.LComputeBlockBy1Loop\@: + ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 quad +.if \RowCount\() > 3 + add rbx,4 # advance matrix A plus 3 rows by 1 quad +.endif + add rsi,64 # advance matrix B + sub rbp,4 + jnz .LComputeBlockBy1Loop\@ + + .endm + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (rdi) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8S8CopyPackAAvx2. + + B (rsi) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8S8CopyPackBAvx2. + + C (rdx) - Supplies the address of matrix C. + + QuadCountK (rcx) - Supplies the number of quad columns from matrix A and + the number of quad rows from matrix B to iterate over. + + CountM (r8) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (r9) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc - Supplies the first dimension of matrix C. + + RowSumVector - Supplies the sum of each row from matrix A multiplied by the + zero point offset of matrix B. These values are accumulated into every + row of matrix C. + + ColumnSumVector - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + DepthValue - Supplies the value CountK multiplied by the zero point offset + of matrix A multiplied by the zero point offset of matrix B. This value + is accumulated into every element of matrix C. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into.
+ +Return Value: + + Returns the number of rows handled. + +--*/ + + .globl C_UNDERSCORE(MlasGemmU8S8KernelAvx2) +C_UNDERSCORE(MlasGemmU8S8KernelAvx2): + + push rbp + push rbx + push r12 + push r13 + + mov rax,.LGemmU8X8KernelFrame_ldc[rsp] + shl rax,2 # convert ldc to bytes + shl rcx,2 # convert to row length + movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] + mov r11,rdi + mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp] + mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp] + vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] + vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] + +// +// Process CountM rows of the matrices. +// + + cmp r8,3 + ja .LProcessCountM4 + je .LProcessCountM3 + cmp r8,1 + je .LProcessCountM1 + +.LProcessCountM2: + ProcessCountM 2 + +.LProcessCountM4: + mov r8d,4 # return 4 rows handled + ProcessCountM 4, Fallthrough + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + mov eax,r8d + vzeroupper + + pop r13 + pop r12 + pop rbx + pop rbp + ret + +.LProcessCountM1: + ProcessCountM 1 + +.LProcessCountM3: + ProcessCountM 3 + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S new file mode 100644 index 0000000000..c994545362 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S @@ -0,0 +1,136 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8S8KernelAvx512BW.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM). + + This implementation uses AVX512BW instructions. + +--*/ + +#include "asmmacro.h" +#include "QgemmU8S8KernelAvx512Common.h" + + .intel_syntax noprefix + + .text + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate a single cell of the + output block.
+ +Arguments: + + AccumReg - Supplies the register to accumulate into. + + Mult1Reg - Supplies the first multiplication operand register. + + Mult2Reg - Supplies the second multiplication operand register. + +Implicit Arguments: + + zmm4 - Supplies a scratch register for intermediate results. + + zmm5 - Supplies a 512-bit vector with the broadcasted word value 0x0001. + +--*/ + + .macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg + + vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\() # multiply unsigned bytes of Mult1Reg by signed bytes of Mult2Reg, add adjacent pairs into saturated words + vpmaddwd zmm4,zmm4,zmm5 # multiply words by 0x0001 and add adjacent pairs into dwords + vpaddd \AccumReg\(),\AccumReg\(),zmm4 # accumulate the dword sums + + .endm + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r14 - Supplies the stride in bytes between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators.
+ +--*/ + + .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset + +.if \ColumnCount\() >= 48 + vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] # matrix B vector for the zmm26-zmm31 accumulators + vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()] # matrix B vector for the zmm20-zmm25 accumulators + vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()] # matrix B vector for the zmm14-zmm19 accumulators +.elseif \ColumnCount\() >= 32 + vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()] # same register roles as above, zmm0 unused + vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()] +.else + vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()] # only the zmm14-zmm19 accumulators are used +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell
zmm17,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2" + + .endm + +// +// Generate the GEMM kernel. +// + +GemmU8X8KernelAvx512Function U8S8, Avx512BW + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h new file mode 100644 index 0000000000..41437286a5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Common.h @@ -0,0 +1,88 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8S8KernelAvx512Common.h + +Abstract: + + This module contains common kernel macros and structures for the quantized + integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and + AVX512VNNI kernels. + +--*/ + +#include "QgemmU8X8KernelAvx512Common.h" + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times, advancing the matrix A and matrix B data pointers. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rbx - Supplies the address into the matrix A data plus 3 rows.
+ + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r14 - Supplies the stride in bytes between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLoop ColumnCount, RowCount + + mov rbp,rcx # reload row length remaining + +.if ((\RowCount\() & 1) == 0) + sub rbp,4*4 # unrolled path is emitted only for even row counts + jb .LProcessRemainingBlocks\@ # fewer than 4 quads remain + +.LComputeBlockBy4Loop\@: + ComputeBlock \ColumnCount\(), \RowCount\(), 0*64, 0 + ComputeBlock \ColumnCount\(), \RowCount\(), 1*64, 4 + ComputeBlock \ColumnCount\(), \RowCount\(), 2*64, 8 + ComputeBlock \ColumnCount\(), \RowCount\(), 3*64, 12 + add rdi,4*4 # advance matrix A by 4 quads +.if \RowCount\() > 3 + add rbx,4*4 # advance matrix A plus 3 rows by 4 quads +.endif + add rsi,4*64 # advance matrix B + sub rbp,4*4 # decrement quads remaining + jae .LComputeBlockBy4Loop\@ + +.LProcessRemainingBlocks\@: + add rbp,4*4 # correct for over-subtract above + jz .LComputeBlockLoopExit\@ # no quads remain +.endif + +.LComputeBlockBy1Loop\@: + ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 + add rdi,4 # advance matrix A by 1 quad +.if \RowCount\() > 3 + add rbx,4 # advance matrix A plus 3 rows by 1 quad +.endif + add rsi,64 # advance matrix B + sub rbp,4 # decrement quads remaining + jnz .LComputeBlockBy1Loop\@ + +.LComputeBlockLoopExit\@: + + .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S new file mode 100644 index 0000000000..bf806d9265 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S @@ -0,0 +1,106 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8S8KernelAvx512Vnni.s + +Abstract: + + This module implements the kernels for the quantized integer matrix/matrix + multiply operation (QGEMM).
+ + This implementation uses AVX512VNNI instructions. + +--*/ + +#include "asmmacro.h" +#include "QgemmU8S8KernelAvx512Common.h" +#include "AssembleAvx512Vnni.h" + + .intel_syntax noprefix + + .text + +/*++ + +Macro Description: + + This macro generates code to multiply and accumulate each row of the output + block. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: + + rbx - Supplies the address into the matrix A data plus 3 rows. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r14 - Supplies the stride in bytes between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset + +.if \ColumnCount\() >= 48 + vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] # matrix B vector for the zmm26-zmm31 accumulators + vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()] # matrix B vector for the zmm20-zmm25 accumulators + vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()] # matrix B vector for the zmm14-zmm19 accumulators +.elseif \ColumnCount\() >= 32 + vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()] # same register roles as above, zmm0 unused + vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()] +.else + vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()] # only the zmm14-zmm19 accumulators are used +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48,
"VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2" + + .endm + +// +// Generate the GEMM kernel.
+// + +GemmU8X8KernelAvx512Function U8S8, Avx512Vnni + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S index f53059e313..d443561f62 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S @@ -18,11 +18,10 @@ Abstract: --*/ #include "asmmacro.h" +#include "QgemmU8X8KernelAvx2Common.h" .intel_syntax noprefix - .text - // // Stack frame layout for the U8U8 CopyPackA routine. // @@ -47,22 +46,7 @@ Abstract: .equ .LGemmU8U8CopyPackBFrame_ReturnAddress, 16 .equ .LGemmU8U8CopyPackBFrame_offa, 24 -// -// Stack frame layout for the U8U8 kernel. -// - - .equ .LGemmU8U8KernelFrame_mask, -8 - .equ .LGemmU8U8KernelFrame_SavedR14, 0 - .equ .LGemmU8U8KernelFrame_SavedR13, 8 - .equ .LGemmU8U8KernelFrame_SavedR12, 16 - .equ .LGemmU8U8KernelFrame_SavedRbx, 24 - .equ .LGemmU8U8KernelFrame_SavedRbp, 32 - .equ .LGemmU8U8KernelFrame_ReturnAddress, 40 - .equ .LGemmU8U8KernelFrame_ldc, 48 - .equ .LGemmU8U8KernelFrame_RowSumVector, 56 - .equ .LGemmU8U8KernelFrame_ColumnSumVector, 64 - .equ .LGemmU8U8KernelFrame_DepthValue, 72 - .equ .LGemmU8U8KernelFrame_ZeroMode, 80 + .text /*++ @@ -138,13 +122,15 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): // // Process 4 rows of matrix A in a loop. // -// For each row, zero extend the source bytes to 16-bits and write to the packed -// buffer. The packed buffer has the same data ordering as the source bytes, but -// the stride is CountK aligned up to an even number of 16-bit values. +// Zero extend the source bytes to 16-bits and write to the packed buffer. +// +// The packed buffer has the same data ordering as the source bytes, but CountK +// is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding +// bytes are zero filled. // // These 16-bit values are also accumulated into an intermediate per-row -// accumulator. 
CountK cannot be greater than 256 to avoid overflowing these -// 16-bit accumulators. +// accumulator. CountK cannot be greater than 128 to avoid overflowing these +// signed 16-bit accumulators. // sub r11,4 @@ -164,12 +150,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): jb .LCopyPackA.ProcessRemainingColumnsM4 .LCopyPackA.ProcessNextColumnLoopM4: - lea rax,[rdx+r10*2] # compute matrix A plus two rows + lea rax,[rdx+r10*2] # compute matrix A plus 2 rows vpmovzxbw ymm4,XMMWORD PTR [rdx] vpmovzxbw ymm5,XMMWORD PTR [rdx+r10] vpmovzxbw ymm6,XMMWORD PTR [rax] vpmovzxbw ymm7,XMMWORD PTR [rax+r10] - lea rax,[rcx+r12*4] # compute matrix D plus two rows + lea rax,[rcx+r12*4] # compute matrix D plus 2 rows vmovdqu YMMWORD PTR [rcx],ymm4 vmovdqu YMMWORD PTR [rcx+r12*2],ymm5 vmovdqu YMMWORD PTR [rax],ymm6 @@ -194,7 +180,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] test bl,8 # (CountK & 8) != 0? jz .LCopyPackA.CopyRemainingCountKLessThan8M4 - lea r13,[rdx+r10*2] # compute matrix A plus two rows + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows mov rax,QWORD PTR [rdx] mov QWORD PTR [rbp],rax mov rax,QWORD PTR [rdx+r10] @@ -209,7 +195,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): .LCopyPackA.CopyRemainingCountKLessThan8M4: test bl,4 # (CountK & 4) != 0? jz .LCopyPackA.CopyRemainingCountKLessThan4M4 - lea r13,[rdx+r10*2] # compute matrix A plus two rows + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows mov eax,DWORD PTR [rdx] mov DWORD PTR [rbp],eax mov eax,DWORD PTR [rdx+r10] @@ -224,7 +210,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): .LCopyPackA.CopyRemainingCountKLessThan4M4: test bl,2 # (CountK & 2) != 0? 
jz .LCopyPackA.CopyRemainingCountKLessThan2M4 - lea r13,[rdx+r10*2] # compute matrix A plus two rows + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows movzx eax,WORD PTR [rdx] mov WORD PTR [rbp],ax movzx eax,WORD PTR [rdx+r10] @@ -239,7 +225,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): .LCopyPackA.CopyRemainingCountKLessThan2M4: test bl,1 # (CountK & 1) != 0? jz .LCopyPackA.ProcessPaddedMatrixADataM4 - lea r13,[rdx+r10*2] # compute matrix A plus two rows + lea r13,[rdx+r10*2] # compute matrix A plus 2 rows movzx eax,BYTE PTR [rdx] mov BYTE PTR [rbp],al movzx eax,BYTE PTR [rdx+r10] @@ -258,7 +244,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): vpmovzxbw ymm5,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+16] vpmovzxbw ymm6,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32] vpmovzxbw ymm7,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+48] - lea rax,[rcx+r12*4] # compute matrix D plus two rows + lea rax,[rcx+r12*4] # compute matrix D plus 2 rows vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 vpmaskmovd YMMWORD PTR [rcx+r12*2],ymm9,ymm5 vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6 @@ -276,20 +262,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2): // .LCopyPackA.ReduceRowSumVectorM4: - vpunpckldq ymm4,ymm0,ymm1 # [A5 B5 A4 B4 A1 B1 A0 B0] - vpunpckhdq ymm5,ymm0,ymm1 # [A7 B7 A6 B6 A3 B3 A2 B2] - vpunpckldq ymm6,ymm2,ymm3 # [C5 D5 C4 D4 C1 D1 C0 D0] - vpunpckhdq ymm7,ymm2,ymm3 # [C7 D7 C6 D6 C3 D3 C2 D2] - vpunpcklqdq ymm0,ymm4,ymm6 # [A4 B4 C4 D4 A0 B0 C0 D0] - vpunpckhqdq ymm1,ymm4,ymm6 # [A5 B5 C5 D5 A1 B1 C1 D1] - vpunpcklqdq ymm2,ymm5,ymm7 # [A6 B6 C6 D6 A2 B2 C2 D2] - vpunpckhqdq ymm3,ymm5,ymm7 # [A7 B7 C7 D7 A3 B3 C3 D3] - vpaddw ymm0,ymm0,ymm1 # reduction - vpaddw ymm0,ymm0,ymm2 - vpaddw ymm0,ymm0,ymm3 + vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 + vphaddw ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 + vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 vextracti128 xmm1,ymm0,1 # extract high pairs - vpaddw 
xmm0,xmm0,xmm1 # reduction - vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce + vpaddw xmm0,xmm0,xmm1 # reduce low/high pairs + vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce 32-bit sum vmovdqu XMMWORD PTR [r9],xmm0 add r9,4*4 # advance row sum vector by 4 dwords sub r11,4 # subtract rows remaining @@ -436,7 +414,6 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): push rbx mov r10,rdx - mov r11,rcx vpbroadcastw ymm5,WORD PTR .LGemmU8U8CopyPackBFrame_offa[rsp] // @@ -450,7 +427,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): // Process 16 columns of matrix B in a loop. // - sub r11,16 + sub rcx,16 jb .LCopyPackB.ProcessRemainingColumns .LCopyPackB.ProcessNextColumnN16: @@ -463,9 +440,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): jb .LCopyPackB.ProcessRemainingRowsN16 .LCopyPackB.ProcessNextRowLoopN16: - vmovdqu xmm2,XMMWORD PTR [rdx] # load two rows + vmovdqu xmm2,XMMWORD PTR [rdx] # load 2 rows vmovdqu xmm3,XMMWORD PTR [rdx+r10] - lea rdx,[rdx+r10*2] # advance matrix B by two rows + lea rdx,[rdx+r10*2] # advance matrix B by 2 rows vpunpcklbw xmm4,xmm2,xmm3 # interleave row data vpunpckhbw xmm3,xmm2,xmm3 vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows @@ -475,7 +452,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rdi,32 # advance matrix D by 32 bytes vpaddw ymm0,ymm0,ymm4 # accumulate per column vpaddw ymm1,ymm1,ymm3 - sub rbx,2 # subtract columns remaining + sub rbx,2 # subtract rows remaining jae .LCopyPackB.ProcessNextRowLoopN16 .LCopyPackB.ProcessRemainingRowsN16: @@ -495,12 +472,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): vpmaddwd ymm1,ymm1,ymm5 # multiply by offset and reduce vmovdqu YMMWORD PTR [r9],ymm0 vmovdqu YMMWORD PTR [r9+32],ymm1 - add r9,64 # advance column sum vector by 16 dwords - sub r11,16 # subtract columns remaining + add r9,64 # advance column sum vector by 16 DWORDs + sub rcx,16 # subtract columns remaining jae .LCopyPackB.ProcessNextColumnN16 .LCopyPackB.ProcessRemainingColumns: - add r11,16 # correct for over-subtract 
above + add rcx,16 # correct for over-subtract above jnz .LCopyPackB.ProcessColumnNUnaligned // @@ -527,7 +504,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): .LCopyPackB.ProcessNextRowLoopNUnaligned: mov rdx,rsi lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] - test r11b,8 # (CountN & 8) != 0? + test cl,8 # (CountN & 8) != 0? jz .LCopyPackB.CopyRemainingCountNLessThan8K2 mov rax,QWORD PTR [rdx] mov QWORD PTR [rbp],rax @@ -537,7 +514,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rbp,8 # advance padded buffer destination .LCopyPackB.CopyRemainingCountNLessThan8K2: - test r11b,4 # (CountN & 4) != 0? + test cl,4 # (CountN & 4) != 0? jz .LCopyPackB.CopyRemainingCountNLessThan4K2 mov eax,DWORD PTR [rdx] mov DWORD PTR [rbp],eax @@ -547,7 +524,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rbp,4 # advance padded buffer destination .LCopyPackB.CopyRemainingCountNLessThan4K2: - test r11b,2 # (CountN & 2) != 0? + test cl,2 # (CountN & 2) != 0? jz .LCopyPackB.CopyRemainingCountNLessThan2K2 movzx eax,WORD PTR [rdx] mov WORD PTR [rbp],ax @@ -557,7 +534,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rbp,2 # advance padded buffer destination .LCopyPackB.CopyRemainingCountNLessThan2K2: - test r11b,1 # (CountN & 1) != 0? + test cl,1 # (CountN & 1) != 0? 
jz .LCopyPackB.ProcessPaddedMatrixBDataK2 movzx eax,BYTE PTR [rdx] mov BYTE PTR [rbp],al @@ -575,9 +552,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): vpmovzxbw ymm3,xmm3 vpaddw ymm0,ymm0,ymm4 # accumulate per column vpaddw ymm1,ymm1,ymm3 - lea rsi,[rsi+r10*2] # advance next matrix B by two rows + lea rsi,[rsi+r10*2] # advance next matrix B by 2 rows add rdi,32 # advance matrix D by 32 bytes - sub r8,2 # subtract columns remaining + sub r8,2 # subtract rows remaining jae .LCopyPackB.ProcessNextRowLoopNUnaligned .LCopyPackB.ProcessRemainingRowsNUnaligned: @@ -585,7 +562,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): jz .LCopyPackB.ReduceColumnSumVectorNUnaligned mov rdx,rsi lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] - test r11b,8 # (CountN & 8) != 0? + test cl,8 # (CountN & 8) != 0? jz .LCopyPackB.CopyRemainingCountNLessThan8K1 mov rax,QWORD PTR [rdx] mov QWORD PTR [rbp],rax @@ -593,7 +570,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rbp,8 # advance padded buffer destination .LCopyPackB.CopyRemainingCountNLessThan8K1: - test r11b,4 # (CountN & 4) != 0? + test cl,4 # (CountN & 4) != 0? jz .LCopyPackB.CopyRemainingCountNLessThan4K1 mov eax,DWORD PTR [rdx] mov DWORD PTR [rbp],eax @@ -601,7 +578,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rbp,4 # advance padded buffer destination .LCopyPackB.CopyRemainingCountNLessThan4K1: - test r11b,2 # (CountN & 2) != 0? + test cl,2 # (CountN & 2) != 0? jz .LCopyPackB.CopyRemainingCountNLessThan2K1 movzx eax,WORD PTR [rdx] mov WORD PTR [rbp],ax @@ -609,7 +586,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2): add rbp,2 # advance padded buffer destination .LCopyPackB.CopyRemainingCountNLessThan2K1: - test r11b,1 # (CountN & 1) != 0? + test cl,1 # (CountN & 1) != 0? 
jz .LCopyPackB.ProcessPaddedMatrixBDataK1 movzx eax,BYTE PTR [rdx] mov BYTE PTR [rbp],al @@ -659,13 +636,12 @@ Implicit Arguments: .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg -.if \ColumnCount\() == 16 vpmaddwd ymm3,ymm2,ymm0 +.if \ColumnCount\() == 16 vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 vpmaddwd ymm2,ymm2,ymm1 vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 .else - vpmaddwd ymm3,ymm2,ymm0 vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 .endif @@ -696,7 +672,7 @@ Implicit Arguments: rsi - Supplies the address into the matrix B data. - r10 - Supplies the length in bytes of a row from matrix A. + rcx - Supplies the length in bytes of a row from matrix A. ymm4-ymm15 - Supplies the block accumulators. @@ -708,15 +684,15 @@ Implicit Arguments: EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]" EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+r10+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+r10*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9" EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]" EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+r10+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13" - 
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+r10*2+\BroadcastOffset\()]" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15" .endm @@ -725,8 +701,8 @@ Implicit Arguments: Macro Description: - This macro generates code to produce an output block for a set of columns - and rows. + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. Arguments: @@ -736,271 +712,55 @@ Arguments: Implicit Arguments: - rax - Supplies the length in bytes of a row from matrix C. + rbx - Supplies the address into the matrix A data plus 3 rows. rdi - Supplies the address into the matrix A data. rsi - Supplies the address into the matrix B data. - rcx - Supplies the number of paired columns from matrix A and the number of - paired rows from matrix B to iterate over. + rcx - Supplies the length in bytes of a row from matrix A. - r10 - Supplies the length in bytes of a row from matrix A. - - r12 - Supplies the address of the row sum vector. - - r13 - Supplies the address of the column sum vector. + ymm4-ymm15 - Supplies the block accumulators. --*/ - .macro ProduceOutputBlock ColumnCount, RowCount + .macro ComputeBlockLoop ColumnCount, RowCount -// -// Initialize the accumulators with the sum of the global depth value constant, -// the column sums, and the row sums. 
-// - - vpbroadcastd ymm1,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp] -.if \ColumnCount\() == 16 - vpaddd ymm0,ymm1,YMMWORD PTR [r13] - vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] - add r13,16*4 # advance ColumnSumVector by 16 columns -.else - vpaddd ymm1,ymm1,YMMWORD PTR [r13] -.endif - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1" - -// -// Iterate over PairedCountK elements from matrix A and matrix B. -// -// Unrolling the loop to do two iterations improves performance slightly at the -// cost of larger code size. Balance this by only unrolling for the common case -// of computing 16 columns for an even number of rows. 
-// - - mov rbp,rcx # reload PairedCountK -.if \RowCount\() > 3 - lea rbx,[r10*2+r10] - add rbx,rdi # compute matrix A plus 3 rows -.endif + mov rbp,rcx # reload row length remaining .if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0) - sub rbp,2 - jb .LProcessRemainingBlocks.\ColumnCount\().\RowCount\() + sub rbp,2*4 + jb .LProcessRemainingBlocks\@ -.LComputeBlockLoop.\ColumnCount\().\RowCount\(): +.LComputeBlockBy2Loop\@: ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4 add rdi,2*4 # advance matrix A by 2 pairs .if \RowCount\() > 3 add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs .endif - add rsi,2*32 # advance matrix B by 64 columns - sub rbp,2 # subtract pairs remaining - jae .LComputeBlockLoop.\ColumnCount\().\RowCount\() + add rsi,2*32 # advance matrix B + sub rbp,2*4 + jae .LComputeBlockBy2Loop\@ -.LProcessRemainingBlocks.\ColumnCount\().\RowCount\(): - add rbp,2 # correct for over-subtract above - jz .LComputeBlockLoopExit.\ColumnCount\().\RowCount\() +.LProcessRemainingBlocks\@: + add rbp,2*4 # correct for over-subtract above + jz .LComputeBlockLoopExit\@ ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 - add rsi,32 # advance matrix B by 32 columns + add rsi,32 # advance matrix B .else -.LComputeBlockLoop.\ColumnCount\().\RowCount\(): +.LComputeBlockBy1Loop\@: ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 add rdi,4 # advance matrix A by 1 pair .if \RowCount\() > 3 add rbx,4 # advance matrix A plus 3 rows by 1 pair .endif - add rsi,32 - dec rbp # decrement pairs remaining - jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\() + add rsi,32 # advance matrix B + sub rbp,4 + jnz .LComputeBlockBy1Loop\@ .endif -.LComputeBlockLoopExit.\ColumnCount\().\RowCount\(): -.if \RowCount\() > 3 - lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows - add rbx,rax -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. 
- -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - rdx - Supplies the address of matrix C. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rcx - Supplies the number of paired columns from matrix A and the number of - paired rows from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - r12 - Supplies the address of the row sum vector. - - r13 - Supplies the address of the column sum vector. - - r14b - Supplies the zero mode flag. - ---*/ - - .macro ProcessCountM RowCount, Fallthrough - - cmp r9,8 - jbe .LProcessRemainingCountN.\RowCount\() - -.LProcessNextColumnLoop16xN.\RowCount\(): - ProduceOutputBlock 16, \RowCount\() - sub r9,16 - jb .LOutputMasked16xNBlock.\RowCount\() - test r14b,r14b # ZeroMode? 
- jnz .LSkipAccumulateOutput16xNBlock.\RowCount\() - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]" - -.LSkipAccumulateOutput16xNBlock.\RowCount\(): - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15" - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,r11 # reload matrix A - cmp r9,8 - ja 
.LProcessNextColumnLoop16xN.\RowCount\() - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN.\RowCount\(): - ProduceOutputBlock 8, \RowCount\() - cmp r9,8 - jb .LOutputMasked8xNBlock.\RowCount\() - test r14b,r14b # ZeroMode? - jnz .LSkipAccumulateOutput8xNBlock.\RowCount\() - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput8xNBlock.\RowCount\(): - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15" - jmp .LExitKernel - -.LOutputMasked16xNBlock.\RowCount\(): - test r14b,r14b # ZeroMode? 
- jnz .LSkipAccumulateOutputMasked16xNBlock.\RowCount\() - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutputMasked16xNBlock.\RowCount\(): - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" - add rdx,8*4 # advance matrix C by 8 columns -.if \RowCount\() > 3 - add rbx,8*4 # advance matrix C plus 3 rows by 8 columns -.endif - add r9,8 # correct for over-subtract above - -.LOutputMasked8xNBlock.\RowCount\(): - mov DWORD PTR .LGemmU8U8KernelFrame_mask[rsp],r9d - vpbroadcastd ymm0,DWORD PTR .LGemmU8U8KernelFrame_mask[rsp] - vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - test r14b,r14b # ZeroMode? 
- jnz .LSkipAccumulateOutputMasked8xNBlock.\RowCount\() - EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14" - -.LSkipAccumulateOutputMasked8xNBlock.\RowCount\(): - EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5" - EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7" - EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9" - EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11" - EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13" - EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15" -.ifb \Fallthrough\() - jmp .LExitKernel -.endif +.LComputeBlockLoopExit\@: .endm @@ -1021,8 +781,8 @@ Arguments: C (rdx) - Supplies the address of matrix C. - PairedCountK (rcx) - Supplies the number of paired columns from matrix A and - the number of paired rows from matrix B to iterate over. + PairCountK (rcx) - Supplies the number of pair columns from matrix A and + the number of pair rows from matrix B to iterate over. CountM (r8) - Supplies the maximum number of rows that can be processed for matrix A and matrix C. 
The actual number of rows handled for this @@ -1042,8 +802,8 @@ Arguments: every column of matrix C. DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrixA multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. + of matrix A multiplied by the zero point offset of matrix B. This value + is accumulated into every element of matrix C. ZeroMode - Supplies true if the output matrix must be zero initialized, else false if the output matrix is accumulated into. @@ -1061,15 +821,14 @@ C_UNDERSCORE(MlasGemmU8U8KernelAvx2): push rbx push r12 push r13 - push r14 - mov rax,.LGemmU8U8KernelFrame_ldc[rsp] + mov rax,.LGemmU8X8KernelFrame_ldc[rsp] shl rax,2 # convert ldc to bytes - lea r10,[rcx*4] + shl rcx,2 # convert to row length + movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] mov r11,rdi - mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp] - mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp] - movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp] + mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp] + mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp] // // Process CountM rows of the matrices. @@ -1102,7 +861,6 @@ C_UNDERSCORE(MlasGemmU8U8KernelAvx2): mov eax,r8d vzeroupper - pop r14 pop r13 pop r12 pop rbx diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S index bacb29a9a1..1e251d94ab 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S @@ -28,40 +28,27 @@ Abstract: Macro Description: - This macro generates code to multiply and accumulator a single row of the + This macro generates code to multiply and accumulate a single cell of the output block. Arguments: - ColumnCount - Supplies the number of columns to produce. + AccumReg - Supplies the register to accumulate into. 
- Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 32). + Mult1Reg - Supplies the first multiplication operand register. - Vec2Reg - Supplies the low block accumulator register. + Mult2Reg - Supplies the second multiplication operand register. Implicit Arguments: - zmm28 - Supplies the first vector loaded from matrix B. - - zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount - is 32). - - zmm30 - Supplies the broadcast value loaded from matrix A. + zmm4 - Supplies a scratch register for intermediate results. --*/ - .macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg + .macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg -.if \ColumnCount\() == 32 - vpmaddwd zmm31,zmm30,zmm28 - vpaddd \Vec1Reg\(),\Vec1Reg\(),zmm31 - vpmaddwd zmm30,zmm30,zmm29 - vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm30 -.else - vpmaddwd zmm31,zmm30,zmm28 - vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm31 -.endif + vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\() + vpaddd \AccumReg\(),\AccumReg\(),zmm4 .endm @@ -78,36 +65,62 @@ Arguments: RowCount - Supplies the number of rows to produce. -Implicit Arguments: + VectorOffset - Supplies the byte offset from matrix B to fetch elements. - rdi - Supplies the address into the matrix A data. + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: rbx - Supplies the address into the matrix A data plus 3 rows. + rdi - Supplies the address into the matrix A data. + rsi - Supplies the address into the matrix B data. - r10 - Supplies the length in bytes of a row from matrix A. + rcx - Supplies the length in bytes of a row from matrix A. - zmm16-zmm27 - Supplies the block accumulators. + r14 - Supplies the stride in bytes between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators. 
--*/ - .macro ComputeBlock ColumnCount, RowCount + .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - vpmovzxbw zmm28,YMMWORD PTR [rsi] - EmitIfCountGE \ColumnCount\(), 32, "vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), zmm16, zmm17" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), zmm18, zmm19" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), zmm20, zmm21" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), zmm22, zmm23" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), zmm24, zmm25" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), zmm26, zmm27" +.if \ColumnCount\() >= 48 + vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()] + vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()] + vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()] +.elseif \ColumnCount\() >= 32 + vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()] + vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()] +.else + vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()] +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2" + EmitIfCountGE 
\RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2" .endm @@ -115,6 +128,6 @@ Implicit Arguments: // Generate the GEMM kernel. 
// -GemmU8U8KernelAvx512Function Avx512BW +GemmU8X8KernelAvx512Function U8U8, Avx512BW .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h index 3abd87b7ce..486dd5667b 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Common.h @@ -16,28 +16,14 @@ Abstract: --*/ -// -// Stack frame layout for the U8U8 kernel. -// - - .equ .LGemmU8U8KernelFrame_SavedR14, 0 - .equ .LGemmU8U8KernelFrame_SavedR13, 8 - .equ .LGemmU8U8KernelFrame_SavedR12, 16 - .equ .LGemmU8U8KernelFrame_SavedRbx, 24 - .equ .LGemmU8U8KernelFrame_SavedRbp, 32 - .equ .LGemmU8U8KernelFrame_ReturnAddress, 40 - .equ .LGemmU8U8KernelFrame_ldc, 48 - .equ .LGemmU8U8KernelFrame_RowSumVector, 56 - .equ .LGemmU8U8KernelFrame_ColumnSumVector, 64 - .equ .LGemmU8U8KernelFrame_DepthValue, 72 - .equ .LGemmU8U8KernelFrame_ZeroMode, 80 +#include "QgemmU8X8KernelAvx512Common.h" /*++ -Macro Description: + Macro Description: - This macro generates code to produce an output block for a set of columns - and rows. + This macro generates code to execute the block compute macro multiple + times while advancing the matrix A and matrix B data pointers. Arguments: @@ -47,315 +33,32 @@ Arguments: Implicit Arguments: - rax - Supplies the length in bytes of a row from matrix C. + rbx - Supplies the address into the matrix A data plus 3 rows. rdi - Supplies the address into the matrix A data. rsi - Supplies the address into the matrix B data. - rcx - Supplies the number of paired columns from matrix A and the number of - paired rows from matrix B to iterate over. + rcx - Supplies the length in bytes of a row from matrix A. - r10 - Supplies the length in bytes of a row from matrix A. + r14 - Supplies the stride in bytes between packed blocks of matrix B. - r12 - Supplies the address of the row sum vector. - - r13 - Supplies the address of the column sum vector. 
+ zmm14-zmm31 - Supplies the block accumulators. --*/ - .macro ProduceOutputBlock ColumnCount, RowCount + .macro ComputeBlockLoop ColumnCount, RowCount -// -// Initialize the accumulators with the sum of the global depth value constant, -// the column sums, and the row sums. -// + mov rbp,rcx # reload row length remaining - vpbroadcastd zmm31,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp] -.if \ColumnCount\() == 32 - vpaddd zmm30,zmm31,ZMMWORD PTR [r13] - vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64] - add r13,32*4 # advance ColumnSumVector by 32 columns -.else - vpaddd zmm31,zmm31,ZMMWORD PTR [r13] -.endif - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm16,zmm30,DWORD PTR [r12]{1to16}" - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17,zmm31,DWORD PTR [r12]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm18,zmm30,DWORD PTR [r12+4]{1to16}" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19,zmm31,DWORD PTR [r12+4]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm20,zmm30,DWORD PTR [r12+8]{1to16}" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21,zmm31,DWORD PTR [r12+8]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm22,zmm30,DWORD PTR [r12+12]{1to16}" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm31,DWORD PTR [r12+12]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm30,DWORD PTR [r12+16]{1to16}" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25,zmm31,DWORD PTR [r12+16]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm26,zmm30,DWORD PTR [r12+20]{1to16}" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27,zmm31,DWORD PTR [r12+20]{1to16}" - -// -// Iterate over PairedCountK elements from matrix A and matrix B. 
-// - - mov rbp,rcx # reload PairedCountK -.if \RowCount\() > 3 - lea rbx,[r10*2+r10] - add rbx,rdi # compute matrix A plus 3 rows -.endif - -.LComputeBlockLoop.\ColumnCount\().\RowCount\(): - ComputeBlock \ColumnCount\(), \RowCount\() +.LComputeBlockBy1Loop\@: + ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0 add rdi,4 # advance matrix A by 1 pair .if \RowCount\() > 3 add rbx,4 # advance matrix A plus 3 rows by 1 pair .endif - add rsi,32 - dec rbp # decrement pairs remaining - jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\() - -.if \RowCount\() > 3 - lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows - add rbx,rax -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - rdx - Supplies the address of matrix C. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rcx - Supplies the number of paired columns from matrix A and the number of - paired rows from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - r12 - Supplies the address of the row sum vector. - - r13 - Supplies the address of the column sum vector. - - r14b - Supplies the zero mode flag. - ---*/ - - .macro ProcessCountM RowCount - - cmp r9,16 - jbe .LProcessRemainingCountN.\RowCount\() - -.LProcessNextColumnLoop32xN.\RowCount\(): - ProduceOutputBlock 32, \RowCount\() - lea rsi,[rsi+r10*8] # advance matrix B by 8*PairedCountK - test r14b,r14b # ZeroMode? 
- jnz .LSkipAccumulateOutput32xNBlock.\RowCount\() - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm16,zmm16,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm18,zmm18,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm22,zmm22,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm26,zmm26,ZMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput32xNBlock.\RowCount\(): - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm16" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm18" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm20" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm22" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm26" - add rdx,16*4 # advance matrix C by 16 columns -.if \RowCount\() > 3 - add rbx,16*4 # advance matrix C plus 3 rows by 16 columns -.endif - sub r9,16 - -.LOutput16xNBlock.\RowCount\(): - sub r9,16 - jae .LOutput16xNBlockWithMask.\RowCount\() - lea rcx,[r9+16] # correct for over-subtract above - mov ebp,1 - shl ebp,cl - dec ebp - kmovw k1,ebp # update mask for remaining columns - xor r9,r9 # no more columns remaining - -.LOutput16xNBlockWithMask.\RowCount\(): - test r14b,r14b # ZeroMode? 
- jnz .LSkipAccumulateOutput16xNBlockWithMask.\RowCount\() - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21{k1},zmm21,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23{k1},zmm23,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25{k1},zmm25,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27{k1},zmm27,ZMMWORD PTR [rbx+rax*2]" - -.LSkipAccumulateOutput16xNBlockWithMask.\RowCount\(): - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm17" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm19" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm21" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm23" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm25" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm27" - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,r11 # reload matrix A - cmp r9,16 - ja .LProcessNextColumnLoop32xN.\RowCount\() - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN.\RowCount\(): - ProduceOutputBlock 16, \RowCount\() - jmp .LOutput16xNBlock.\RowCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates the common AVX512 code for the inner kernel to compute - matrix multiplication. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro GemmU8U8KernelAvx512Function Isa - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8U8CopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8U8CopyPackBAvx2. 
- - C (rdx) - Supplies the address of matrix C. - - PairedCountK (rcx) - Supplies the number of paired columns from matrix A and - the number of paired rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumVector - Supplies the sum of each row from matrix A multiplied by the - zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumVector - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - DepthValue - Supplies the value CountK multiplied by the zero point offset - of matrixA multplied by the zero point offset of matrix B. This value is - accumulated into every element of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .globl C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\()) -C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\()): - - push rbp - push rbx - push r12 - push r13 - push r14 - - mov rax,.LGemmU8U8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - lea r10,[rcx*4] - mov r11,rdi - mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp] - mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp] - movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp] - mov ebp,-1 - kmovw k1,ebp # update mask to write all columns - -// -// Process CountM rows of the matrices. 
-// - - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - mov r8d,6 # return 6 rows handled - ProcessCountM 6 - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,r8d - - pop r14 - pop r13 - pop r12 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 + add rsi,32 # advance matrix B + sub rbp,4 + jnz .LComputeBlockBy1Loop\@ .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S index 76a85427d5..c946f50bee 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S @@ -38,50 +38,69 @@ Arguments: RowCount - Supplies the number of rows to produce. -Implicit Arguments: + VectorOffset - Supplies the byte offset from matrix B to fetch elements. - rdi - Supplies the address into the matrix A data. + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + +Implicit Arguments: rbx - Supplies the address into the matrix A data plus 3 rows. + rdi - Supplies the address into the matrix A data. + rsi - Supplies the address into the matrix B data. - r10 - Supplies the length in bytes of a row from matrix A. + rcx - Supplies the length in bytes of a row from matrix A. - zmm16-zmm27 - Supplies the block accumulators. + r14 - Supplies the stride in bytes between packed blocks of matrix B. + + zmm14-zmm31 - Supplies the block accumulators. 
--*/ - .macro ComputeBlock ColumnCount, RowCount + .macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset - vpmovzxbw zmm28,YMMWORD PTR [rsi] -.if \ColumnCount\() == 32 - vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8] - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]" - EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm16,zmm28,zmm30" - EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm17,zmm29,zmm30" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]" - EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm18,zmm28,zmm30" - EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm19,zmm29,zmm30" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]" - EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm20,zmm28,zmm30" - EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm21,zmm29,zmm30" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm22,zmm28,zmm30" - EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm23,zmm29,zmm30" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]" - EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm24,zmm28,zmm30" - EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm25,zmm29,zmm30" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]" - EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm26,zmm28,zmm30" - EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm27,zmm29,zmm30" +.if \ColumnCount\() >= 32 +.if \ColumnCount\() >= 48 + vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()] + vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()] + vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()] .else - EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm17,zmm28,rdi" - EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm19,zmm28,rdi,r10,1" - EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm21,zmm28,rdi,r10,2" - EmitIfCountGE 
\RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm23,zmm28,rbx" - EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm25,zmm28,rbx,r10,1" - EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm27,zmm28,rbx,r10,2" + vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()] + vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()] +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm26,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm20,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm14,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm27,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm21,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm15,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm28,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm22,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm16,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm29,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm23,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm17,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm30,zmm3,zmm0" + EmitIfCount2GE 
\RowCount\(), 5, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm24,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm18,zmm3,zmm2" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm31,zmm3,zmm0" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm25,zmm3,zmm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm19,zmm3,zmm2" +.else + vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()] + EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm14,zmm2,rdi,\BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm15,zmm2,rdi,\BroadcastOffset\(),rcx,1" + EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm16,zmm2,rdi,\BroadcastOffset\(),rcx,2" + EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm17,zmm2,rbx,\BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm18,zmm2,rbx,\BroadcastOffset\(),rcx,1" + EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm19,zmm2,rbx,\BroadcastOffset\(),rcx,2" .endif .endm @@ -90,6 +109,6 @@ Implicit Arguments: // Generate the GEMM kernel. // -GemmU8U8KernelAvx512Function Avx512Vnni +GemmU8X8KernelAvx512Function U8U8, Avx512Vnni .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h new file mode 100644 index 0000000000..172bd9fab1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h @@ -0,0 +1,273 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelAvx2Common.h + +Abstract: + + This module contains common kernel macros and structures for the quantized + integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels. 
+ +--*/ + +// +// Stack frame layout for the U8S8 and U8U8 kernels. +// + + .equ .LGemmU8X8KernelFrame_mask, -8 + .equ .LGemmU8X8KernelFrame_SavedR13, 0 + .equ .LGemmU8X8KernelFrame_SavedR12, 8 + .equ .LGemmU8X8KernelFrame_SavedRbx, 16 + .equ .LGemmU8X8KernelFrame_SavedRbp, 24 + .equ .LGemmU8X8KernelFrame_ReturnAddress, 32 + .equ .LGemmU8X8KernelFrame_ldc, 40 + .equ .LGemmU8X8KernelFrame_RowSumVector, 48 + .equ .LGemmU8X8KernelFrame_ColumnSumVector, 56 + .equ .LGemmU8X8KernelFrame_DepthValue, 64 + .equ .LGemmU8X8KernelFrame_ZeroMode, 72 + +/*++ + +Macro Description: + + This macro generates code to produce an output block for a set of columns + and rows. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + + ymm4-ymm15 - Supplies the block accumulators. + +--*/ + + .macro ProduceOutputBlock ColumnCount, RowCount + +// +// Initialize the accumulators with the sum of the global depth value constant, +// the column sums, and the row sums. 
+// + + vpbroadcastd ymm1,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp] +.if \ColumnCount\() == 16 + vpaddd ymm0,ymm1,YMMWORD PTR [r13] + vpaddd ymm1,ymm1,YMMWORD PTR [r13+32] + add r13,16*4 # advance ColumnSumVector by 16 columns +.else + vpaddd ymm1,ymm1,YMMWORD PTR [r13] +.endif + EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]" + EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]" + EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]" + EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]" + EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]" + EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1" + +// +// Iterate over the length of a matrix A row to produce the output accumulators. 
+// + +.if \RowCount\() > 3 + lea rbx,[rcx*2+rcx] + add rbx,rdi # compute matrix A plus 3 rows +.endif + ComputeBlockLoop \ColumnCount\(), \RowCount\() +.if \RowCount\() > 3 + lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows + add rbx,rax +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address of matrix A. + + rsi - Supplies the address of matrix B. + + rdx - Supplies the address of matrix C. + + r11 - Supplies the address of matrix A. + + r9 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + rcx - Supplies the length in bytes of a row from matrix A. + + r10b - Supplies the zero mode flag. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + +--*/ + + .macro ProcessCountM RowCount, Fallthrough + + cmp r9,8 + jbe .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop16xN\@: + ProduceOutputBlock 16, \RowCount\() + sub r9,16 + jb .LOutputMasked16xNBlock\@ + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutput16xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]" + +.LSkipAccumulateOutput16xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15" + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,r11 # reload matrix A + cmp r9,8 + ja .LProcessNextColumnLoop16xN\@ + test r9,r9 + jz 
.LExitKernel + +.LProcessRemainingCountN\@: + ProduceOutputBlock 8, \RowCount\() + cmp r9,8 + jb .LOutputMasked8xNBlock\@ + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput8xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput8xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15" + jmp .LExitKernel + +.LOutputMasked16xNBlock\@: + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutputMasked16xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutputMasked16xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" + EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" + EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" + EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10" + EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12" + EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14" + add rdx,8*4 # advance matrix C by 8 columns +.if \RowCount\() > 3 + add rbx,8*4 # advance matrix C plus 3 rows by 8 columns +.endif + add r9,8 # correct for over-subtract above + +.LOutputMasked8xNBlock\@: + mov DWORD PTR .LGemmU8X8KernelFrame_mask[rsp],r9d + vpbroadcastd ymm0,DWORD PTR .LGemmU8X8KernelFrame_mask[rsp] + vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutputMasked8xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]" + EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4" + EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6" + EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8" + EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10" + EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12" + EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14" + +.LSkipAccumulateOutputMasked8xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5" + EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7" + EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9" + EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11" + EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13" + EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15" +.ifb \Fallthrough\() + jmp .LExitKernel +.endif + + .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h new file mode 100644 index 0000000000..18f82b15ad --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h @@ -0,0 +1,403 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + QgemmU8X8KernelAvx512Common.h + +Abstract: + + This module contains common kernel macros and structures for the quantized + integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and + AVX512VNNI kernels. + +--*/ + +// +// Stack frame layout for the U8S8 and U8U8 kernels. +// + + .equ .LGemmU8X8KernelFrame_SavedR14, 0 + .equ .LGemmU8X8KernelFrame_SavedR13, 8 + .equ .LGemmU8X8KernelFrame_SavedR12, 16 + .equ .LGemmU8X8KernelFrame_SavedRbx, 24 + .equ .LGemmU8X8KernelFrame_SavedRbp, 32 + .equ .LGemmU8X8KernelFrame_ReturnAddress, 40 + .equ .LGemmU8X8KernelFrame_ldc, 48 + .equ .LGemmU8X8KernelFrame_RowSumVector, 56 + .equ .LGemmU8X8KernelFrame_ColumnSumVector, 64 + .equ .LGemmU8X8KernelFrame_DepthValue, 72 + .equ .LGemmU8X8KernelFrame_ZeroMode, 80 + +/*++ + +Macro Description: + + This macro generates code to produce an output block for a set of columns + and rows. + +Arguments: + + ColumnCount - Supplies the number of columns to produce. + + RowCount - Supplies the number of rows to produce. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address into the matrix A data. + + rsi - Supplies the address into the matrix B data. + + rcx - Supplies the length in bytes of a row from matrix A. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + +--*/ + + .macro ProduceOutputBlock ColumnCount, RowCount + +// +// Initialize the accumulators with the sum of the global depth value constant, +// the column sums, and the row sums. 
+// + + vpbroadcastd zmm3,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp] +.if \ColumnCount\() >= 32 +.if \ColumnCount\() >= 48 + vpaddd zmm2,zmm3,ZMMWORD PTR [r13] + vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64] + vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128] +.else + vpaddd zmm1,zmm3,ZMMWORD PTR [r13] + vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64] +.endif + add_immed r13,\ColumnCount\()*4 # advance ColumnSumVector by N columns +.else + vpaddd zmm0,zmm3,ZMMWORD PTR [r13] +.endif + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r12]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r12]{1to16}" + EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r12]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r12+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r12+4]{1to16}" + EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r12+4]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r12+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r12+8]{1to16}" + EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r12+8]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r12+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r12+12]{1to16}" + EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r12+12]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r12+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r12+16]{1to16}" + EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r12+16]{1to16}" + EmitIfCount2GE \RowCount\(), 6, 
\ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r12+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r12+20]{1to16}" + EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r12+20]{1to16}" + +// +// Iterate over the length of a matrix A row to produce the output accumulators. +// + +.if \RowCount\() > 3 + lea rbx,[rcx*2+rcx] + add rbx,rdi # compute matrix A plus 3 rows +.endif + ComputeBlockLoop \ColumnCount\(), \RowCount\() +.if \RowCount\() > 3 + lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows + add rbx,rax +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + rax - Supplies the length in bytes of a row from matrix C. + + rdi - Supplies the address of matrix A. + + rsi - Supplies the address of matrix B. + + rdx - Supplies the address of matrix C. + + r11 - Supplies the address of matrix A. + + r9 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + rcx - Supplies the length in bytes of a row from matrix A. + + r10b - Supplies the zero mode flag. + + r12 - Supplies the address of the row sum vector. + + r13 - Supplies the address of the column sum vector. + + r14 - Supplies the stride in bytes of between packed blocks of matrix B. + +--*/ + + .macro ProcessCountM RowCount + + cmp r9,32 + ja .LProcessNextColumnLoop48xN\@ + cmp r9,16 + jbe .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop32xN\@: + ProduceOutputBlock 32, \RowCount\() + add rsi,r14 # advance matrix B by packed block stride + +.LOutput32xNBlock\@: + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutput32xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput32xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm23" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25" + add rdx,16*4 # advance matrix C by 16 columns +.if \RowCount\() > 3 + add rbx,16*4 # advance matrix C plus 3 rows by 16 columns +.endif + sub r9,16 + +.LOutput16xNBlock\@: + sub r9,16 + jae .LOutput16xNBlockWithMask\@ + lea rcx,[r9+16] # correct for over-subtract above + mov ebp,1 + shl ebp,cl + dec ebp + kmovw k1,ebp # update mask for remaining columns + xor r9,r9 # no more columns remaining + +.LOutput16xNBlockWithMask\@: + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutput16xNBlockWithMask\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput16xNBlockWithMask\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19" + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,r11 # reload matrix A + cmp r9,32 + ja .LProcessNextColumnLoop48xN\@ + cmp r9,16 + ja .LProcessNextColumnLoop32xN\@ + test r9,r9 + jz .LExitKernel + +.LProcessRemainingCountN\@: + ProduceOutputBlock 16, \RowCount\() + jmp .LOutput16xNBlock\@ + +.LProcessNextColumnLoop48xN\@: + ProduceOutputBlock 48, \RowCount\() + lea rsi,[rsi+r14*2] # advance matrix B by packed block stride + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutput48xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]" + EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]" + EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]" + EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]" + EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]" + EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]" + +.LSkipAccumulateOutput48xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26" + EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27" + EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28" + EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm29" + EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30" + EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31" + add rdx,16*4 # advance matrix C by 16 columns +.if \RowCount\() > 3 + add rbx,16*4 # advance matrix C plus 3 rows by 16 columns +.endif + sub r9,16 + jmp .LOutput32xNBlock\@ + + .endm + +/*++ + +Macro Description: + + This macro generates the common AVX512 code for the inner kernel to compute + matrix multiplication. + +Arguments: + + Type - Supplies the kernel type string for function tags. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro GemmU8X8KernelAvx512Function Type, Isa + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (rdi) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmU8X8CopyPackAAvx2. + + B (rsi) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmU8X8CopyPackBAvx2. + + C (rdx) - Supplies the address of matrix C. 
+ + PairedCountK (rcx) - Supplies the number of paired columns from matrix A and + the number of paired rows from matrix B to iterate over. + + CountM (r8) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (r9) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc - Supplies the first dimension of matrix C. + + RowSumVector - Supplies the sum of each row from matrix A multiplied by the + zero point offset of matrix B. These values are accumulated into every + row of matrix C. + + ColumnSumVector - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + DepthValue - Supplies the value CountK multiplied by the zero point offset + of matrix A multiplied by the zero point offset of matrix B. This value is + accumulated into every element of matrix C. + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled.
+ +--*/ + + .globl C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()) +C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()): + + push rbp + push rbx + push r12 + push r13 + push r14 + + mov rax,.LGemmU8X8KernelFrame_ldc[rsp] + shl rax,2 # convert ldc to bytes + shl rcx,2 # convert to row length + movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] + mov r11,rdi + mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp] + mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp] + mov ebp,-1 + kmovw k1,ebp # update mask to write all columns +.ifeqs "\Type\()", "U8S8" +.ifeqs "\Isa\()", "Avx512BW" + neg ebp + vpbroadcastw zmm5,ebp # generate 512-bit word vector [0x0001] +.endif + mov r14,rcx + shl r14,4 # compute matrix B packed stride +.else + lea r14,[rcx*8] # compute matrix B packed stride +.endif + +// +// Process CountM rows of the matrices. +// + + cmp r8,5 + ja .LProcessCountM6 + je .LProcessCountM5 + cmp r8,3 + ja .LProcessCountM4 + je .LProcessCountM3 + cmp r8,1 + je .LProcessCountM1 + +.LProcessCountM2: + ProcessCountM 2 + +.LProcessCountM4: + ProcessCountM 4 + +.LProcessCountM6: + mov r8d,6 # return 6 rows handled + ProcessCountM 6 + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + mov eax,r8d + vzeroupper + + pop r14 + pop r13 + pop r12 + pop rbx + pop rbp + ret + +.LProcessCountM1: + ProcessCountM 1 + +.LProcessCountM3: + ProcessCountM 3 + +.LProcessCountM5: + ProcessCountM 5 + + .endm diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer.cc b/onnxruntime/core/providers/cpu/math/matmul_integer.cc index eab5434d24..68ddf891d3 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cpu/math/matmul_integer.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/framework/op_kernel_context_internal.h" #include "core/providers/cpu/math/matmul_integer.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/qmath.h" @@ -35,6 +36,9 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template <> Status MatMulInteger::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool(); + auto a = ctx->Input(0); auto b = ctx->Input(1); ORT_ENFORCE(a != nullptr && b != nullptr); @@ -71,13 +75,16 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const { b_offset, y->template MutableData() + helper.OutputOffsets()[i], static_cast(helper.N()), - nullptr); + thread_pool); } return Status::OK(); } template <> Status MatMulInteger::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool(); + auto a = ctx->Input(0); auto b = ctx->Input(1); ORT_ENFORCE(a != nullptr && b != nullptr); @@ -107,15 +114,19 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const { } } - // NOTE: Eigen based implementation is a reference implementation for accuracy only for (int i = 0; i < static_cast(helper.OutputOffsets().size()); i++) { - EigenCastGEMM( - a->template Data() + helper.LeftOffsets()[i], - b->template Data() + helper.RightOffsets()[i], - y->template MutableData() + helper.OutputOffsets()[i], - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K())); + QGemmu8s8_s32(static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K()), + a->template Data() + helper.LeftOffsets()[i], + static_cast(helper.K()), + 0, + b->template Data() + helper.RightOffsets()[i], + static_cast(helper.N()), + 0, + y->template MutableData() + helper.OutputOffsets()[i], + static_cast(helper.N()), + thread_pool); } return Status::OK(); } diff --git a/onnxruntime/core/util/qmath.cc b/onnxruntime/core/util/qmath.cc index 
9372ad29d8..08f69a3649 100644 --- a/onnxruntime/core/util/qmath.cc +++ b/onnxruntime/core/util/qmath.cc @@ -3,9 +3,50 @@ #include "core/util/qmath.h" #include "core/common/common.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" + +#if defined(_M_AMD64) || defined(__x86_64__) || defined(_M_IX86) || defined(__i386__) +#define MLAS_SUPPORTS_GEMM_U8X8 +#else +// default to gemmlowp when building for arm devices +#ifndef USE_GEMMLOWP +#define USE_GEMMLOWP +#endif +#endif + +#ifdef USE_GEMMLOWP +#include "core/util/gemmlowp_common.h" +#endif namespace onnxruntime { +void QGemmu8s8_s32( + int M, + int N, + int K, + const uint8_t* lhs_data, + int lda, + const uint8_t lhs_offset, + const int8_t* rhs_data, + int ldb, + const int8_t rhs_offset, + int32_t* result_data, + int ldc, + concurrency::ThreadPool* thread_pool) { +#ifdef MLAS_SUPPORTS_GEMM_U8X8 + + MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool); + +#else + ORT_ENFORCE(lhs_offset == 0 && rhs_offset == 0, "For Eigen, zero point must be zero"); + ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For Eigen only RowMajor*RowMajor=RowMajor format is supported"); + + EigenCastGEMM(lhs_data, rhs_data, result_data, M, N, K); + +#endif +} + void QGemmu8u8_s32( int M, int N, @@ -21,13 +62,13 @@ void QGemmu8u8_s32( concurrency::ThreadPool* thread_pool) { #ifdef USE_GEMMLOWP - ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported"); + ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported"); - GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool); + GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool); #else - MlasQgemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool); + MlasGemm(M, N, K, lhs_data, 
lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool); #endif } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index f6ce7bcaa2..cd519b90c0 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -1,26 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once -// default to gemmlowp when building for arm devices -#ifndef USE_GEMMLOWP -#if defined(_M_ARM64) || defined(__aarch64__) -#define USE_GEMMLOWP -#endif -#if defined(_M_ARM) || defined(__arm__) -#define USE_GEMMLOWP -#endif -#endif - -#ifdef USE_GEMMLOWP -#include "core/util/gemmlowp_common.h" -#else -#include "core/mlas/inc/mlas.h" -#endif #include "core/platform/threadpool.h" -#include -#include namespace onnxruntime { +void QGemmu8s8_s32( + int M, + int N, + int K, + const uint8_t* lhs_data, + int lda, + const uint8_t lhs_offset, + const int8_t* rhs_data, + int ldb, + const int8_t rhs_offset, + int32_t* result_data, + int ldc, + concurrency::ThreadPool* thread_pool); + void QGemmu8u8_s32( int M, int N, diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index 64690d8ee7..8f2ea7063b 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -35,7 +35,7 @@ Abstract: #endif #if defined(_M_IX86) || defined(__i386__) || defined(_M_AMD64) || defined(__x86_64__) -#define MLAS_HAS_QGEMM_U8U8 +#define MLAS_HAS_QGEMM_U8X8 #endif MLAS_THREADPOOL* threadpool = nullptr; @@ -452,9 +452,10 @@ public: } }; -#ifdef MLAS_HAS_QGEMM_U8U8 +#ifdef MLAS_HAS_QGEMM_U8X8 -class MlasQgemmU8U8Test : public MlasTestBase +template +class MlasQgemmU8X8Test : public MlasTestBase { private: void @@ -467,11 +468,11 @@ private: ) { const uint8_t* A = BufferA.GetBuffer(K * M); - const uint8_t* B = BufferB.GetBuffer(N * K); + const xint8_t* B = 
BufferB.GetBuffer(N * K); int32_t* C = BufferC.GetBuffer(N * M); int32_t* CReference = BufferCReference.GetBuffer(N * M); - Test(M, N, K, A, K, offa, B, N, offb, C, CReference, N); + Test(M, N, K, A, K, offa, B, N, xint8_t(offb), C, CReference, N); } void @@ -482,9 +483,9 @@ private: const uint8_t* A, size_t lda, uint8_t offa, - const uint8_t* B, + const xint8_t* B, size_t ldb, - uint8_t offb, + xint8_t offb, int32_t* C, int32_t* CReference, size_t ldc @@ -493,7 +494,7 @@ private: std::fill_n(C, M * N, -1); std::fill_n(CReference, M * N, -1); - MlasQgemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool); + MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool); ReferenceQgemm(M, N, K, A, lda, offa, B, ldb, offb, CReference, ldc); for (size_t f = 0; f < M * N; f++) { @@ -511,9 +512,9 @@ private: const uint8_t* A, size_t lda, uint8_t offa, - const uint8_t* B, + const xint8_t* B, size_t ldb, - uint8_t offb, + xint8_t offb, int32_t* C, size_t ldc ) @@ -523,7 +524,7 @@ private: for (size_t n = 0; n < N; n++) { const uint8_t* a = A + (m * lda); - const uint8_t* b = B + n; + const xint8_t* b = B + n; int32_t* c = C + (m * ldc) + n; int32_t sum = 0; @@ -539,7 +540,7 @@ private: } MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferB; MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; @@ -565,7 +566,7 @@ public: void ) override { - static const uint8_t zero_points[] = { 0, 18, 128, 157, 231, 255 }; + static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 }; for (size_t a = 0; a < _countof(zero_points); a++) { uint8_t offa = zero_points[a]; @@ -1970,9 +1971,10 @@ main( printf("SGEMM tests.\n"); std::make_unique()->ExecuteShort(); -#ifdef MLAS_HAS_QGEMM_U8U8 +#ifdef MLAS_HAS_QGEMM_U8X8 printf("QGEMM tests.\n"); - std::make_unique()->ExecuteShort(); + std::make_unique>()->ExecuteShort(); + std::make_unique>()->ExecuteShort(); #endif printf("Conv2D tests.\n");