MLAS: add U8S8 MatMul operation (#1895)

Implement the second round of changes for quantization inside MLAS. This adds a MatMul operation for U8xS8=S32 for x86/x64 processors.
This commit is contained in:
Tracy Sharpe 2019-09-24 18:15:11 -07:00 committed by GitHub
parent d3cb2a5572
commit 28a62f7728
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 5415 additions and 1572 deletions

View file

@ -44,6 +44,9 @@ if(MSVC)
enable_language(ASM_MASM)
set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm
@ -158,6 +161,7 @@ else()
set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx")
set(mlas_platform_srcs_avx2
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S
@ -175,6 +179,8 @@ else()
set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f")
set(mlas_platform_srcs_avx512bw
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S
)

View file

@ -135,7 +135,24 @@ MlasSgemm(
void
MLASCALL
MlasQgemm(
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const int8_t* B,
size_t ldb,
int8_t offb,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,

View file

@ -139,7 +139,7 @@ VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg
;
; This macro builds a VNNI instruction of the form:
;
; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale]
; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale+ByteOffset]
;
; Arguments:
;
@ -151,15 +151,20 @@ VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg
;
; BaseReg - Specifies the base register of the broadcast operand.
;
; ByteOffset - Specifies the DWORD aligned byte offset for the broadcast
; operand.
;
; IndexReg - Specifies the optional index register of the broadcast operand.
;
; Scale - Specifies the scaling factor of the optional index register.
;
VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale
VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
LOCAL Payload0, Payload1, Payload2, ModRMByte, SibByte
.errnz (ByteOffset AND 3)
Payload0 = 002h ; "0F 38" prefix
Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7)
IFNB <IndexReg>
@ -183,6 +188,9 @@ IFNB <IndexReg>
ELSE
ModRMByte = ModRMByte + (GprIndex_&BaseReg& AND 7)
ENDIF
IF ByteOffset NE 0
ModRMByte = ModRMByte + 040h ; indicate disp8 byte offset
ENDIF
IFNB <IndexReg>
SibByte = 0
@ -199,34 +207,36 @@ ENDIF
SibByte = SibByte + (GprIndex_&BaseReg& AND 7)
ENDIF
IFNB <IndexReg>
db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte, SibByte
ELSE
db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte
IFNB <IndexReg>
db SibByte
ENDIF
IF ByteOffset NE 0
db ByteOffset SHR 2
ENDIF
ENDM
VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
ENDM
VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
ENDM
VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
ENDM
VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
ENDM

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,130 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvx512BW.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX512BW instructions.
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8S8KernelAvx512Common.inc
.list
;
; Macro Description:
;
; This macro generates code to multiply and accumulator a single cell of the
; output block.
;
; Arguments:
;
; AccumReg - Supplies the register to accumulate into.
;
; Mult1Reg - Supplies the first multiplication operand register.
;
; Mult2Reg - Supplies the second multiplication operand register.
;
; Implicit Arguments:
;
; zmm4 - Supplies a scratch register for intermediate results.
;
; zmm5 - Supplies a 512-bit with the broadcasted word value 0x0001.
;
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
vpmaddubsw zmm4,Mult1Reg,Mult2Reg
vpmaddwd zmm4,zmm4,zmm5
vpaddd AccumReg,AccumReg,zmm4
ENDM
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
IF ColumnCount GE 48
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset]
ELSE
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset]
ENDIF
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
ENDM
;
; Generate the GEMM kernel.
;
GemmU8X8KernelAvx512Function U8S8, Avx512BW
END

View file

@ -0,0 +1,91 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvx512Common.inc
;
; Abstract:
;
; This module contains common kernel macros and structures for the quantized
; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
; AVX512VNNI kernels.
;
;--
INCLUDE QgemmU8X8KernelAvx512Common.inc
;
; Macro Description:
;
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockBy4Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ComputeBlockLoopExit
mov rsi,r9 ; reload row length remaining
IF ((RowCount AND 1) EQ 0)
sub rsi,4*4
jb ProcessRemainingBlocks
ComputeBlockBy4Loop:
ComputeBlock ColumnCount, RowCount, 0*64, 0
ComputeBlock ColumnCount, RowCount, 1*64, 4
ComputeBlock ColumnCount, RowCount, 2*64, 8
ComputeBlock ColumnCount, RowCount, 3*64, 12
add rcx,4*4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,4*64 ; advance matrix B
sub rsi,4*4 ; decrement quads remaining
jae ComputeBlockBy4Loop
ProcessRemainingBlocks:
add rsi,4*4 ; correct for over-subtract above
jz ComputeBlockLoopExit
ENDIF
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
ENDIF
add rdx,64 ; advance matrix B
sub rsi,4 ; decrement quads remaining
jnz ComputeBlockBy1Loop
ComputeBlockLoopExit:
ENDM

View file

@ -0,0 +1,102 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8S8KernelAvx512Vnni.asm
;
; Abstract:
;
; This module implements the kernels for the quantized integer matrix/matrix
; multiply operation (QGEMM).
;
; This implementation uses AVX512VNNI instructions.
;
;--
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8S8KernelAvx512Common.inc
INCLUDE AssembleAvx512Vnni.inc
.list
;
; Macro Description:
;
; This macro generates code to multiply and accumulate each row of the output
; block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
IF ColumnCount GE 48
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset]
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset]
ELSE
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset]
ENDIF
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2>
ENDM
;
; Generate the GEMM kernel.
;
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
END

View file

@ -19,10 +19,9 @@
.xlist
INCLUDE mlasi.inc
INCLUDE QgemmU8X8KernelAvx2Common.inc
.list
EXTERN MlasMaskMoveAvx:NEAR
;
; Stack frame layout for the U8U8 CopyPackA routine.
;
@ -73,44 +72,6 @@ GemmU8U8CopyPackBFrame STRUCT
GemmU8U8CopyPackBFrame ENDS
;
; Stack frame layout for the U8U8 kernel.
;
GemmU8U8KernelFrame STRUCT
SavedXmm6 OWORD ?
SavedXmm7 OWORD ?
SavedXmm8 OWORD ?
SavedXmm9 OWORD ?
SavedXmm10 OWORD ?
SavedXmm11 OWORD ?
SavedXmm12 OWORD ?
SavedXmm13 OWORD ?
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
SavedR14 QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumVector QWORD ?
ColumnSumVector QWORD ?
DepthValue QWORD ?
ZeroMode QWORD ?
GemmU8U8KernelFrame ENDS
;++
;
; Routine Description:
@ -157,10 +118,10 @@ GemmU8U8KernelFrame ENDS
push_reg r12
push_reg r13
alloc_stack (GemmU8U8CopyPackAFrame.SavedR13)
save_xmm128_avx xmm6,GemmU8U8CopyPackAFrame.SavedXmm6
save_xmm128_avx xmm7,GemmU8U8CopyPackAFrame.SavedXmm7
save_xmm128_avx xmm8,GemmU8U8CopyPackAFrame.SavedXmm8
save_xmm128_avx xmm9,GemmU8U8CopyPackAFrame.SavedXmm9
save_xmm128 xmm6,GemmU8U8CopyPackAFrame.SavedXmm6
save_xmm128 xmm7,GemmU8U8CopyPackAFrame.SavedXmm7
save_xmm128 xmm8,GemmU8U8CopyPackAFrame.SavedXmm8
save_xmm128 xmm9,GemmU8U8CopyPackAFrame.SavedXmm9
END_PROLOGUE
@ -195,13 +156,15 @@ GemmU8U8KernelFrame ENDS
;
; Process 4 rows of matrix A in a loop.
;
; For each row, zero extend the source bytes to 16-bits and write to the packed
; buffer. The packed buffer has the same data ordering as the source bytes, but
; the stride is CountK aligned up to an even number of 16-bit values.
; Zero extend the source bytes to 16-bits and write to the packed buffer.
;
; The packed buffer has the same data ordering as the source bytes, but CountK
; is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding
; bytes are zero filled.
;
; These 16-bit values are also accumulated into an intermediate per-row
; accumulator. CountK cannot be greater than 256 to avoid overflowing these
; 16-bit accumulators.
; accumulator. CountK cannot be greater than 128 to avoid overflowing these
; signed 16-bit accumulators.
;
sub r9,4
@ -221,12 +184,12 @@ ProcessNextRowM4:
jb ProcessRemainingColumnsM4
ProcessNextColumnLoopM4:
lea rax,[rdx+r8*2] ; compute matrix A plus two rows
lea rax,[rdx+r8*2] ; compute matrix A plus 2 rows
vpmovzxbw ymm4,XMMWORD PTR [rdx]
vpmovzxbw ymm5,XMMWORD PTR [rdx+r8]
vpmovzxbw ymm6,XMMWORD PTR [rax]
vpmovzxbw ymm7,XMMWORD PTR [rax+r8]
lea rax,[rcx+r11*4] ; compute matrix D plus two rows
lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows
vmovdqu YMMWORD PTR [rcx],ymm4
vmovdqu YMMWORD PTR [rcx+r11*2],ymm5
vmovdqu YMMWORD PTR [rax],ymm6
@ -252,7 +215,7 @@ ProcessRemainingColumnsM4:
mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData
test bl,8 ; (CountK & 8) != 0?
jz CopyRemainingCountKLessThan8M4
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
mov rax,QWORD PTR [rdx+r8]
@ -267,7 +230,7 @@ ProcessRemainingColumnsM4:
CopyRemainingCountKLessThan8M4:
test bl,4 ; (CountK & 4) != 0?
jz CopyRemainingCountKLessThan4M4
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
mov eax,DWORD PTR [rdx+r8]
@ -282,7 +245,7 @@ CopyRemainingCountKLessThan8M4:
CopyRemainingCountKLessThan4M4:
test bl,2 ; (CountK & 2) != 0?
jz CopyRemainingCountKLessThan2M4
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
movzx eax,WORD PTR [rdx+r8]
@ -297,7 +260,7 @@ CopyRemainingCountKLessThan4M4:
CopyRemainingCountKLessThan2M4:
test bl,1 ; (CountK & 1) != 0?
jz ProcessPaddedMatrixADataM4
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
movzx eax,BYTE PTR [rdx+r8]
@ -316,7 +279,7 @@ ProcessPaddedMatrixADataM4:
vpmovzxbw ymm5,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+16]
vpmovzxbw ymm6,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32]
vpmovzxbw ymm7,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+48]
lea rax,[rcx+r11*4] ; compute matrix D plus two rows
lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows
vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4
vpmaskmovd YMMWORD PTR [rcx+r11*2],ymm9,ymm5
vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6
@ -334,22 +297,14 @@ ProcessPaddedMatrixADataM4:
;
ReduceRowSumVectorM4:
vpunpckldq ymm4,ymm0,ymm1 ; [A5 B5 A4 B4 A1 B1 A0 B0]
vpunpckhdq ymm5,ymm0,ymm1 ; [A7 B7 A6 B6 A3 B3 A2 B2]
vpunpckldq ymm6,ymm2,ymm3 ; [C5 D5 C4 D4 C1 D1 C0 D0]
vpunpckhdq ymm7,ymm2,ymm3 ; [C7 D7 C6 D6 C3 D3 C2 D2]
vpunpcklqdq ymm0,ymm4,ymm6 ; [A4 B4 C4 D4 A0 B0 C0 D0]
vpunpckhqdq ymm1,ymm4,ymm6 ; [A5 B5 C5 D5 A1 B1 C1 D1]
vpunpcklqdq ymm2,ymm5,ymm7 ; [A6 B6 C6 D6 A2 B2 C2 D2]
vpunpckhqdq ymm3,ymm5,ymm7 ; [A7 B7 C7 D7 A3 B3 C3 D3]
vpaddw ymm0,ymm0,ymm1 ; reduction
vpaddw ymm0,ymm0,ymm2
vpaddw ymm0,ymm0,ymm3
vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0
vphaddw ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2
vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0
vextracti128 xmm1,ymm0,1 ; extract high pairs
vpaddw xmm0,xmm0,xmm1 ; reduction
vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce
vpaddw xmm0,xmm0,xmm1 ; reduce low/high pairs
vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce 32-bit sum
vmovdqu XMMWORD PTR [r12],xmm0
add r12,4*4 ; advance row sum vector by 4 dwords
add r12,4*4 ; advance row sum vector by 4 DWORDs
sub r9,4 ; subtract rows remaining
jae ProcessNextRowM4
@ -449,10 +404,10 @@ ReduceRowSumVectorM1:
ExitRoutine:
vzeroupper
vmovaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp]
vmovaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp]
vmovaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp]
vmovaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp]
movaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp]
movaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp]
movaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp]
add rsp,(GemmU8U8CopyPackAFrame.SavedR13)
BEGIN_EPILOGUE
@ -537,9 +492,9 @@ ProcessNextColumnN16:
jb ProcessRemainingRowsN16
ProcessNextRowLoopN16:
vmovdqu xmm2,XMMWORD PTR [rdx] ; load two rows
vmovdqu xmm2,XMMWORD PTR [rdx] ; load 2 rows
vmovdqu xmm3,XMMWORD PTR [rdx+r8]
lea rdx,[rdx+r8*2] ; advance matrix B by two rows
lea rdx,[rdx+r8*2] ; advance matrix B by 2 rows
vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data
vpunpckhbw xmm3,xmm2,xmm3
vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows
@ -549,7 +504,7 @@ ProcessNextRowLoopN16:
add rcx,32 ; advance matrix D by 32 bytes
vpaddw ymm0,ymm0,ymm4 ; accumulate per column
vpaddw ymm1,ymm1,ymm3
sub rbx,2 ; subtract columns remaining
sub rbx,2 ; subtract rows remaining
jae ProcessNextRowLoopN16
ProcessRemainingRowsN16:
@ -569,7 +524,7 @@ ReduceColumnSumVectorN16:
vpmaddwd ymm1,ymm1,ymm5 ; multiply by offset and reduce
vmovdqu YMMWORD PTR [r11],ymm0
vmovdqu YMMWORD PTR [r11+32],ymm1
add r11,64 ; advance column sum vector by 16 dwords
add r11,16*4 ; advance column sum vector by 16 DWORDs
sub r9,16 ; subtract columns remaining
jae ProcessNextColumnN16
@ -654,7 +609,7 @@ ProcessPaddedMatrixBDataK2:
vpmovzxbw ymm3,xmm3
vpaddw ymm0,ymm0,ymm4 ; accumulate per column
vpaddw ymm1,ymm1,ymm3
lea rsi,[rsi+r8*2] ; advance next matrix B by two rows
lea rsi,[rsi+r8*2] ; advance next matrix B by 2 rows
add rcx,32 ; advance matrix D by 32 bytes
sub r10,2 ; subtract columns remaining
jae ProcessNextRowLoopNUnaligned
@ -739,13 +694,12 @@ ReduceColumnSumVectorNUnaligned:
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
IF ColumnCount EQ 16
vpmaddwd ymm3,ymm2,ymm0
IF ColumnCount EQ 16
vpaddd Vec1Reg,Vec1Reg,ymm3
vpmaddwd ymm2,ymm2,ymm1
vpaddd Vec2Reg,Vec2Reg,ymm2
ELSE
vpmaddwd ymm3,ymm2,ymm0
vpaddd Vec2Reg,Vec2Reg,ymm3
ENDIF
@ -775,7 +729,7 @@ ENDIF
;
; rdx - Supplies the address into the matrix B data.
;
; r10 - Supplies the length in bytes of a row from matrix A.
; r9 - Supplies the length in bytes of a row from matrix A.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
@ -786,15 +740,15 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
EmitIfCountGE ColumnCount, 16, <vpmovzxbw ymm1,XMMWORD PTR [rdx+VectorOffset+16]>
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r10+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r10*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r10+BroadcastOffset]>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, ymm12, ymm13>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r10*2+BroadcastOffset]>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, ymm14, ymm15>
ENDM
@ -802,8 +756,8 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
@ -813,281 +767,59 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the number of paired columns from matrix A and the number of
; paired rows from matrix B to iterate over.
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
; ymm4-ymm15 - Supplies the block accumulators.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockLoop
LOCAL ComputeBlockBy2Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ComputeBlockLoopExit
;
; Initialize the accumulators with the sum of the global depth value constant,
; the column sums, and the row sums.
;
vpbroadcastd ymm1,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp]
IF ColumnCount EQ 16
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
add r13,16*4 ; advance ColumnSumVector by 16 columns
ELSE
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
ENDIF
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r12]>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r12+4]>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r12+8]>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r12+12]>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r12+16]>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r12+20]>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm9,ymm0>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm11,ymm0>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm13,ymm0>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm15,ymm0>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm1>
;
; Iterate over PairedCountK elements from matrix A and matrix B.
;
; Unrolling the loop to do two iterations improves performance slightly at the
; cost of larger code size. Balance this by only unrolling for the common case
; of computing 16 columns for an even number of rows.
;
mov rsi,r9 ; reload PairedCountK
IF RowCount GT 3
lea rbx,[r10*2+r10]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
mov rsi,r9 ; reload row length remaining
IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0)
sub rsi,2
sub rsi,2*4
jb ProcessRemainingBlocks
ComputeBlockLoop:
ComputeBlockBy2Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
ComputeBlock ColumnCount, RowCount, 32, 4
add rcx,2*4 ; advance matrix A by 2 pairs
IF RowCount GT 3
add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs
ENDIF
add rdx,2*32 ; advance matrix B by 64 columns
sub rsi,2 ; subtract pairs remaining
jae ComputeBlockLoop
add rdx,2*32 ; advance matrix B
sub rsi,2*4
jae ComputeBlockBy2Loop
ProcessRemainingBlocks:
add rsi,2 ; correct for over-subtract above
add rsi,2*4 ; correct for over-subtract above
jz ComputeBlockLoopExit
ComputeBlock ColumnCount, RowCount, 0, 0
add rdx,32 ; advance matrix B by 32 columns
add rdx,32 ; advance matrix B
ELSE
ComputeBlockLoop:
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 pair
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
ENDIF
add rdx,32
dec rsi ; decrement pairs remaining
jnz ComputeBlockLoop
add rdx,32 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop
ENDIF
ComputeBlockLoopExit:
IF RowCount GT 3
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
add rbx,rax
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Fallthrough - Supplies a non-blank value if the macro may fall through to
; the ExitKernel label.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the address of matrix A.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the number of paired columns from matrix A and the number of
; paired rows from matrix B to iterate over.
;
; r10 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
;
; r14b - Supplies the zero mode flag.
;
ProcessCountM MACRO RowCount, Fallthrough
LOCAL ProcessNextColumnLoop16xN
LOCAL SkipAccumulateOutput16xNBlock
LOCAL OutputMasked16xNBlock
LOCAL ProcessRemainingCountN
LOCAL SkipAccumulateOutput8xNBlock
LOCAL SkipAccumulateOutputMasked16xNBlock
LOCAL OutputMasked8xNBlock
LOCAL SkipAccumulateOutputMasked8xNBlock
cmp rbp,8
jbe ProcessRemainingCountN
ProcessNextColumnLoop16xN:
ProduceOutputBlock 16, RowCount
sub rbp,16
jb OutputMasked16xNBlock
test r14b,r14b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
SkipAccumulateOutput16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,8
ja ProcessNextColumnLoop16xN
test rbp,rbp
jz ExitKernel
ProcessRemainingCountN:
ProduceOutputBlock 8, RowCount
cmp rbp,8
jb OutputMasked8xNBlock
test r14b,r14b ; ZeroMode?
jnz SkipAccumulateOutput8xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput8xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
jmp ExitKernel
OutputMasked16xNBlock:
test r14b,r14b ; ZeroMode?
jnz SkipAccumulateOutputMasked16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutputMasked16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
add r8,8*4 ; advance matrix C by 8 columns
IF RowCount GT 3
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
ENDIF
add rbp,8 ; correct for over-subtract above
OutputMasked8xNBlock:
mov DWORD PTR GemmU8U8KernelFrame.CountN[rsp],ebp
vpbroadcastd ymm0,DWORD PTR GemmU8U8KernelFrame.CountN[rsp]
vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx]
test r14b,r14b ; ZeroMode?
jnz SkipAccumulateOutputMasked8xNBlock
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
SkipAccumulateOutputMasked8xNBlock:
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
IFB <Fallthrough>
jmp ExitKernel
ENDIF
ENDM
@ -1108,8 +840,8 @@ ENDIF
;
; C (r8) - Supplies the address of matrix C.
;
; PairedCountK (r9) - Supplies the number of paired columns from matrix A and
; the number of paired rows from matrix B to iterate over.
; PairCountK (r9) - Supplies the number of pair columns from matrix A and the
; number of pair rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
@ -1129,7 +861,7 @@ ENDIF
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrixA multplied by the zero point offset of matrix B. This value is
; of matrix A multplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
@ -1149,30 +881,29 @@ ENDIF
push_reg rdi
push_reg r12
push_reg r13
push_reg r14
alloc_stack (GemmU8U8KernelFrame.SavedR14)
save_xmm128_avx xmm6,GemmU8U8KernelFrame.SavedXmm6
save_xmm128_avx xmm7,GemmU8U8KernelFrame.SavedXmm7
save_xmm128_avx xmm8,GemmU8U8KernelFrame.SavedXmm8
save_xmm128_avx xmm9,GemmU8U8KernelFrame.SavedXmm9
save_xmm128_avx xmm10,GemmU8U8KernelFrame.SavedXmm10
save_xmm128_avx xmm11,GemmU8U8KernelFrame.SavedXmm11
save_xmm128_avx xmm12,GemmU8U8KernelFrame.SavedXmm12
save_xmm128_avx xmm13,GemmU8U8KernelFrame.SavedXmm13
save_xmm128_avx xmm14,GemmU8U8KernelFrame.SavedXmm14
save_xmm128_avx xmm15,GemmU8U8KernelFrame.SavedXmm15
alloc_stack (GemmU8X8KernelFrame.SavedR13)
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
mov rdi,rcx
mov rbp,GemmU8U8KernelFrame.CountN[rsp]
mov rax,GemmU8U8KernelFrame.ldc[rsp]
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
lea r10,[r9*4]
mov r11,GemmU8U8KernelFrame.CountM[rsp]
mov r12,GemmU8U8KernelFrame.RowSumVector[rsp]
mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp]
movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp]
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.CountM[rsp]
mov r12,GemmU8X8KernelFrame.RowSumVector[rsp]
mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp]
;
; Process CountM rows of the matrices.
@ -1204,21 +935,20 @@ ProcessCountM6:
ExitKernel:
mov eax,r11d
vzeroupper
vmovaps xmm6,GemmU8U8KernelFrame.SavedXmm6[rsp]
vmovaps xmm7,GemmU8U8KernelFrame.SavedXmm7[rsp]
vmovaps xmm8,GemmU8U8KernelFrame.SavedXmm8[rsp]
vmovaps xmm9,GemmU8U8KernelFrame.SavedXmm9[rsp]
vmovaps xmm10,GemmU8U8KernelFrame.SavedXmm10[rsp]
vmovaps xmm11,GemmU8U8KernelFrame.SavedXmm11[rsp]
vmovaps xmm12,GemmU8U8KernelFrame.SavedXmm12[rsp]
vmovaps xmm13,GemmU8U8KernelFrame.SavedXmm13[rsp]
vmovaps xmm14,GemmU8U8KernelFrame.SavedXmm14[rsp]
vmovaps xmm15,GemmU8U8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8U8KernelFrame.SavedR14)
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR13)
BEGIN_EPILOGUE
pop r14
pop r13
pop r12
pop rdi

View file

@ -25,39 +25,26 @@ INCLUDE QgemmU8U8KernelAvx512Common.inc
;
; Macro Description:
;
; This macro generates code to multiply and accumulator a single row of the
; This macro generates code to multiply and accumulator a single cell of the
; output block.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
; AccumReg - Supplies the register to accumulate into.
;
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
; is 32).
; Mult1Reg - Supplies the first multiplication operand register.
;
; Vec2Reg - Supplies the low block accumulator register.
; Mult2Reg - Supplies the second multiplication operand register.
;
; Implicit Arguments:
;
; zmm28 - Supplies the first vector loaded from matrix B.
;
; zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount
; is 32).
;
; zmm30 - Supplies the broadcast value loaded from matrix A.
; zmm4 - Supplies a scratch register for intermediate results.
;
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
IF ColumnCount EQ 32
vpmaddwd zmm31,zmm30,zmm28
vpaddd Vec1Reg,Vec1Reg,zmm31
vpmaddwd zmm30,zmm30,zmm29
vpaddd Vec2Reg,Vec2Reg,zmm30
ELSE
vpmaddwd zmm31,zmm30,zmm28
vpaddd Vec2Reg,Vec2Reg,zmm31
ENDIF
vpmaddwd zmm4,Mult1Reg,Mult2Reg
vpaddd AccumReg,AccumReg,zmm4
ENDM
@ -73,6 +60,10 @@ ENDIF
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
@ -81,27 +72,49 @@ ENDIF
;
; rdx - Supplies the address into the matrix B data.
;
; r10 - Supplies the length in bytes of a row from matrix A.
; r9 - Supplies the length in bytes of a row from matrix A.
;
; zmm16-zmm27 - Supplies the block accumulators.
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
vpmovzxbw zmm28,YMMWORD PTR [rdx]
EmitIfCountGE ColumnCount, 32, <vpmovzxbw zmm29,YMMWORD PTR [rdx+r10*8]>
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm30,DWORD PTR [rcx]>
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, zmm16, zmm17>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm30,DWORD PTR [rcx+r10]>
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, zmm18, zmm19>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm30,DWORD PTR [rcx+r10*2]>
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, zmm20, zmm21>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm30,DWORD PTR [rbx]>
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, zmm22, zmm23>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm30,DWORD PTR [rbx+r10]>
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, zmm24, zmm25>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]>
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, zmm26, zmm27>
IF ColumnCount GE 48
vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset]
vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset]
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset]
ELSEIF ColumnCount GE 32
vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset]
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset]
ELSE
vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset]
ENDIF
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
ENDM
@ -109,6 +122,6 @@ ComputeBlock MACRO ColumnCount, RowCount
; Generate the GEMM kernel.
;
GemmU8U8KernelAvx512Function Avx512BW
GemmU8X8KernelAvx512Function U8U8, Avx512BW
END

View file

@ -16,39 +16,13 @@
;
;--
;
; Stack frame layout for the U8U8 kernel.
;
GemmU8U8KernelFrame STRUCT
SavedR14 QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumVector QWORD ?
ColumnSumVector QWORD ?
DepthValue QWORD ?
ZeroMode QWORD ?
GemmU8U8KernelFrame ENDS
INCLUDE QgemmU8X8KernelAvx512Common.inc
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
; This macro generates code to execute the block compute macro multiple
; times and advancing the matrix A and matrix B data pointers.
;
; Arguments:
;
@ -58,328 +32,33 @@ GemmU8U8KernelFrame ENDS
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the number of paired columns from matrix A and the number of
; paired rows from matrix B to iterate over.
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10 - Supplies the length in bytes of a row from matrix A.
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
; zmm14-zmm31 - Supplies the block accumulators.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
ComputeBlockLoop MACRO ColumnCount, RowCount
LOCAL ComputeBlockLoop
LOCAL ComputeBlockBy1Loop
;
; Initialize the accumulators with the sum of the global depth value constant,
; the column sums, and the row sums.
;
mov rsi,r9 ; reload row length remaining
vpbroadcastd zmm31,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp]
IF ColumnCount EQ 32
vpaddd zmm30,zmm31,ZMMWORD PTR [r13]
vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64]
add r13,32*4 ; advance ColumnSumVector by 32 columns
ELSE
vpaddd zmm31,zmm31,ZMMWORD PTR [r13]
ENDIF
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm16,zmm30,DWORD BCST [r12]>
EmitIfCountGE RowCount, 1, <vpaddd zmm17,zmm31,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm18,zmm30,DWORD BCST [r12+4]>
EmitIfCountGE RowCount, 2, <vpaddd zmm19,zmm31,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm20,zmm30,DWORD BCST [r12+8]>
EmitIfCountGE RowCount, 3, <vpaddd zmm21,zmm31,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm22,zmm30,DWORD BCST [r12+12]>
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm31,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm30,DWORD BCST [r12+16]>
EmitIfCountGE RowCount, 5, <vpaddd zmm25,zmm31,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm26,zmm30,DWORD BCST [r12+20]>
EmitIfCountGE RowCount, 6, <vpaddd zmm27,zmm31,DWORD BCST [r12+20]>
;
; Iterate over PairedCountK elements from matrix A and matrix B.
;
mov rsi,r9 ; reload PairedCountK
IF RowCount GT 3
lea rbx,[r10*2+r10]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
ComputeBlockLoop:
ComputeBlock ColumnCount, RowCount
ComputeBlockBy1Loop:
ComputeBlock ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 pair
IF RowCount GT 3
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
ENDIF
add rdx,32
dec rsi ; decrement pairs remaining
jnz ComputeBlockLoop
IF RowCount GT 3
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
add rbx,rax
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the address of matrix A.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the number of paired columns from matrix A and the number of
; paired rows from matrix B to iterate over.
;
; r10 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
;
; r14b - Supplies the zero mode flag.
;
ProcessCountM MACRO RowCount
LOCAL ProcessNextColumnLoop32xN
LOCAL SkipAccumulateOutput32xNBlock
LOCAL Output16xNBlock
LOCAL Output16xNBlockWithMask
LOCAL SkipAccumulateOutput16xNBlockWithMask
LOCAL ProcessRemainingCountN
cmp rbp,16
jbe ProcessRemainingCountN
ProcessNextColumnLoop32xN:
ProduceOutputBlock 32, RowCount
lea rdx,[rdx+r10*8] ; advance matrix B by 8*PairedCountK
test r14b,r14b ; ZeroMode?
jnz SkipAccumulateOutput32xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm16,zmm16,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm18,zmm18,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm22,zmm22,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm26,zmm26,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput32xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm16>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm18>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm20>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm22>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm26>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
Output16xNBlock:
sub rbp,16
jae Output16xNBlockWithMask
lea ecx,[ebp+16] ; correct for over-subtract above
mov esi,1
shl esi,cl
dec esi
kmovw k1,esi ; update mask for remaining columns
xor ebp,ebp ; no more columns remaining
Output16xNBlockWithMask:
test r14b,r14b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlockWithMask
EmitIfCountGE RowCount, 1, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm21{k1},zmm21,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm23{k1},zmm23,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm25{k1},zmm25,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm27{k1},zmm27,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput16xNBlockWithMask:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm17>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm19>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm21>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm23>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm25>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm27>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,16
ja ProcessNextColumnLoop32xN
test rbp,rbp
jz ExitKernel
ProcessRemainingCountN:
ProduceOutputBlock 16, RowCount
jmp Output16xNBlock
ENDM
;
; Macro Description:
;
; This macro generates the common AVX512 code for the inner kernel to compute
; matrix multiplication.
;
; Arguments:
;
; Isa - Supplies the instruction set architecture string for function tags.
;
GemmU8U8KernelAvx512Function MACRO Isa
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8U8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8U8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; PairedCountK (r9) - Supplies the number of paired columns from matrix A and
; the number of paired rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumVector - Supplies the sum of each row from matrix A multiplied by the
; zero point offset of matrix B. These values are accumulated into every
; row of matrix C.
;
; ColumnSumVector - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrixA multplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemmU8U8Kernel&Isa&, _TEXT
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
push_reg r14
END_PROLOGUE
mov rdi,rcx
mov rbp,GemmU8U8KernelFrame.CountN[rsp]
mov rax,GemmU8U8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
lea r10,[r9*4]
mov r11,GemmU8U8KernelFrame.CountM[rsp]
mov r12,GemmU8U8KernelFrame.RowSumVector[rsp]
mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp]
movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp]
mov esi,-1
kmovw k1,esi ; update mask to write all columns
;
; Process CountM rows of the matrices.
;
cmp r11,5
ja ProcessCountM6
je ProcessCountM5
cmp r11,3
ja ProcessCountM4
je ProcessCountM3
cmp r11,1
je ProcessCountM1
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
ProcessCountM 4
ProcessCountM6:
mov r11d,6 ; return 6 rows handled
ProcessCountM 6
;
; Restore non-volatile registers and return.
;
ExitKernel:
mov eax,r11d
BEGIN_EPILOGUE
pop r14
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
ProcessCountM5:
ProcessCountM 5
NESTED_END MlasGemmU8U8Kernel&Isa&, _TEXT
add rdx,32 ; advance matrix B
sub rsi,4
jnz ComputeBlockBy1Loop
ENDM

View file

@ -35,6 +35,10 @@ INCLUDE AssembleAvx512Vnni.inc
;
; RowCount - Supplies the number of rows to produce.
;
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
;
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
;
; Implicit Arguments:
;
; rbx - Supplies the address into the matrix A data plus 3 rows.
@ -43,41 +47,56 @@ INCLUDE AssembleAvx512Vnni.inc
;
; rdx - Supplies the address into the matrix B data.
;
; r10 - Supplies the length in bytes of a row from matrix A.
; r9 - Supplies the length in bytes of a row from matrix A.
;
; zmm16-zmm27 - Supplies the block accumulators.
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
; zmm14-zmm31 - Supplies the block accumulators.
;
ComputeBlock MACRO ColumnCount, RowCount
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
vpmovzxbw zmm28,YMMWORD PTR [rdx]
IF ColumnCount EQ 32
vpmovzxbw zmm29,YMMWORD PTR [rdx+r10*8]
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm30,DWORD PTR [rcx]>
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmZmm zmm16,zmm28,zmm30>
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmZmm zmm17,zmm29,zmm30>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm30,DWORD PTR [rcx+r10]>
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmZmm zmm18,zmm28,zmm30>
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmZmm zmm19,zmm29,zmm30>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm30,DWORD PTR [rcx+r10*2]>
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmZmm zmm20,zmm28,zmm30>
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmZmm zmm21,zmm29,zmm30>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm30,DWORD PTR [rbx]>
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmZmm zmm22,zmm28,zmm30>
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmZmm zmm23,zmm29,zmm30>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm30,DWORD PTR [rbx+r10]>
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmZmm zmm24,zmm28,zmm30>
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmZmm zmm25,zmm29,zmm30>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]>
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmZmm zmm26,zmm28,zmm30>
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmZmm zmm27,zmm29,zmm30>
IF ColumnCount GE 32
IF ColumnCount GE 48
vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset]
vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset]
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset]
ELSE
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmBroadcast zmm17,zmm28,rcx>
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmBroadcast zmm19,zmm28,rcx,r10,1>
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmBroadcast zmm21,zmm28,rcx,r10,2>
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmBroadcast zmm23,zmm28,rbx>
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmBroadcast zmm25,zmm28,rbx,r10,1>
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmBroadcast zmm27,zmm28,rbx,r10,2>
vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset]
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset]
ENDIF
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm26,zmm3,zmm0>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm20,zmm3,zmm1>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm14,zmm3,zmm2>
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm27,zmm3,zmm0>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm21,zmm3,zmm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm15,zmm3,zmm2>
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm28,zmm3,zmm0>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm22,zmm3,zmm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm16,zmm3,zmm2>
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm29,zmm3,zmm0>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm23,zmm3,zmm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm17,zmm3,zmm2>
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm30,zmm3,zmm0>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm24,zmm3,zmm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm18,zmm3,zmm2>
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm31,zmm3,zmm0>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm25,zmm3,zmm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm19,zmm3,zmm2>
ELSE
vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset]
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmBroadcast zmm14,zmm2,rcx,BroadcastOffset>
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmBroadcast zmm15,zmm2,rcx,BroadcastOffset,r9,1>
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmBroadcast zmm16,zmm2,rcx,BroadcastOffset,r9,2>
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmBroadcast zmm17,zmm2,rbx,BroadcastOffset>
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmBroadcast zmm18,zmm2,rbx,BroadcastOffset,r9,1>
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmBroadcast zmm19,zmm2,rbx,BroadcastOffset,r9,2>
ENDIF
ENDM
@ -86,6 +105,6 @@ ENDIF
; Generate the GEMM kernel.
;
GemmU8U8KernelAvx512Function Avx512Vnni
GemmU8X8KernelAvx512Function U8U8, Avx512Vnni
END

View file

@ -0,0 +1,302 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8X8KernelAvx2Common.inc
;
; Abstract:
;
; This module contains common kernel macros and structures for the quantized
; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
;
;--
EXTERN MlasMaskMoveAvx:NEAR
;
; Stack frame layout for the U8S8 and U8U8 kernels.
;
GemmU8X8KernelFrame STRUCT
SavedXmm6 OWORD ?
SavedXmm7 OWORD ?
SavedXmm8 OWORD ?
SavedXmm9 OWORD ?
SavedXmm10 OWORD ?
SavedXmm11 OWORD ?
SavedXmm12 OWORD ?
SavedXmm13 OWORD ?
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
Padding QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumVector QWORD ?
ColumnSumVector QWORD ?
DepthValue QWORD ?
ZeroMode QWORD ?
GemmU8X8KernelFrame ENDS
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
;
; ymm4-ymm15 - Supplies the block accumulators.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
;
; Initialize the accumulators with the sum of the global depth value constant,
; the column sums, and the row sums.
;
vpbroadcastd ymm1,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
IF ColumnCount EQ 16
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
add r13,16*4 ; advance ColumnSumVector by 16 columns
ELSE
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
ENDIF
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r12]>
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r12+4]>
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r12+8]>
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r12+12]>
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r12+16]>
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r12+20]>
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm1>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm9,ymm0>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm1>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm11,ymm0>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm1>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm13,ymm0>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm1>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm15,ymm0>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm1>
;
; Iterate over the length of a matrix A row to produce the output accumulators.
;
IF RowCount GT 3
lea rbx,[r9*2+r9]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
ComputeBlockLoop ColumnCount, RowCount
IF RowCount GT 3
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
add rbx,rax
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Fallthrough - Supplies a non-blank value if the macro may fall through to
; the ExitKernel label.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the address of matrix A.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10b - Supplies the zero mode flag.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
;
ProcessCountM MACRO RowCount, Fallthrough
LOCAL ProcessNextColumnLoop16xN
LOCAL SkipAccumulateOutput16xNBlock
LOCAL OutputMasked16xNBlock
LOCAL ProcessRemainingCountN
LOCAL SkipAccumulateOutput8xNBlock
LOCAL SkipAccumulateOutputMasked16xNBlock
LOCAL OutputMasked8xNBlock
LOCAL SkipAccumulateOutputMasked8xNBlock
cmp rbp,8
jbe ProcessRemainingCountN
ProcessNextColumnLoop16xN:
ProduceOutputBlock 16, RowCount
sub rbp,16
jb OutputMasked16xNBlock
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
SkipAccumulateOutput16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,8
ja ProcessNextColumnLoop16xN
test rbp,rbp
jz ExitKernel
ProcessRemainingCountN:
ProduceOutputBlock 8, RowCount
cmp rbp,8
jb OutputMasked8xNBlock
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput8xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput8xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
jmp ExitKernel
OutputMasked16xNBlock:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutputMasked16xNBlock
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutputMasked16xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
add r8,8*4 ; advance matrix C by 8 columns
IF RowCount GT 3
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
ENDIF
add rbp,8 ; correct for over-subtract above
OutputMasked8xNBlock:
mov DWORD PTR GemmU8X8KernelFrame.CountN[rsp],ebp
vpbroadcastd ymm0,DWORD PTR GemmU8X8KernelFrame.CountN[rsp]
vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx]
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutputMasked8xNBlock
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
SkipAccumulateOutputMasked8xNBlock:
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
IFB <Fallthrough>
jmp ExitKernel
ENDIF
ENDM

View file

@ -0,0 +1,438 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; QgemmU8X8KernelAvx512Common.inc
;
; Abstract:
;
; This module contains common kernel macros and structures for the quantized
; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
; AVX512VNNI kernels.
;
;--
;
; Stack frame layout for the U8S8 and U8U8 kernels.
;
GemmU8X8KernelFrame STRUCT
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
SavedR14 QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
SavedRdi QWORD ?
SavedRsi QWORD ?
SavedRbx QWORD ?
SavedRbp QWORD ?
ReturnAddress QWORD ?
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
CountM QWORD ?
CountN QWORD ?
ldc QWORD ?
RowSumVector QWORD ?
ColumnSumVector QWORD ?
DepthValue QWORD ?
ZeroMode QWORD ?
GemmU8X8KernelFrame ENDS
;
; Macro Description:
;
; This macro generates code to produce an output block for a set of columns
; and rows.
;
; Arguments:
;
; ColumnCount - Supplies the number of columns to produce.
;
; RowCount - Supplies the number of rows to produce.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address into the matrix A data.
;
; rdx - Supplies the address into the matrix B data.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
;
ProduceOutputBlock MACRO ColumnCount, RowCount
;
; Initialize the accumulators with the sum of the global depth value constant,
; the column sums, and the row sums.
;
vpbroadcastd zmm3,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
IF ColumnCount GE 32
IF ColumnCount GE 48
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
ELSE
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
ENDIF
add_immed r13,ColumnCount*4 ; advance ColumnSumVector by N columns
ELSE
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
ENDIF
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,DWORD BCST [r12]>
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,DWORD BCST [r12+4]>
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,DWORD BCST [r12+8]>
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,DWORD BCST [r12+12]>
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,DWORD BCST [r12+16]>
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,DWORD BCST [r12+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,DWORD BCST [r12+20]>
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,DWORD BCST [r12+20]>
;
; Iterate over the length of a matrix A row to produce the output accumulators.
;
IF RowCount GT 3
lea rbx,[r9*2+r9]
add rbx,rcx ; compute matrix A plus 3 rows
ENDIF
ComputeBlockLoop ColumnCount, RowCount
IF RowCount GT 3
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
add rbx,rax
ENDIF
ENDM
;
; Macro Description:
;
; This macro generates code to compute matrix multiplication for a fixed set
; of rows.
;
; Arguments:
;
; RowCount - Supplies the number of rows to process.
;
; Implicit Arguments:
;
; rax - Supplies the length in bytes of a row from matrix C.
;
; rcx - Supplies the address of matrix A.
;
; rdx - Supplies the address of matrix B.
;
; r8 - Supplies the address of matrix C.
;
; rdi - Supplies the address of matrix A.
;
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; r9 - Supplies the length in bytes of a row from matrix A.
;
; r10b - Supplies the zero mode flag.
;
; r12 - Supplies the address of the row sum vector.
;
; r13 - Supplies the address of the column sum vector.
;
; r14 - Supplies the stride in bytes of between packed blocks of matrix B.
;
ProcessCountM MACRO RowCount
LOCAL ProcessNextColumnLoop32xN
LOCAL Output32xNBlock
LOCAL SkipAccumulateOutput32xNBlock
LOCAL Output16xNBlock
LOCAL Output16xNBlockWithMask
LOCAL SkipAccumulateOutput16xNBlockWithMask
LOCAL ProcessRemainingCountN
LOCAL ProcessNextColumnLoop48xN
LOCAL SkipAccumulateOutput48xNBlock
cmp rbp,32
ja ProcessNextColumnLoop48xN
cmp rbp,16
jbe ProcessRemainingCountN
ProcessNextColumnLoop32xN:
ProduceOutputBlock 32, RowCount
add rdx,r14 ; advance matrix B by packed block stride
Output32xNBlock:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput32xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm21,zmm21,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm22,zmm22,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput32xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm20>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm21>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm22>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm23>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
Output16xNBlock:
sub rbp,16
jae Output16xNBlockWithMask
lea ecx,[ebp+16] ; correct for over-subtract above
mov esi,1
shl esi,cl
dec esi
kmovw k1,esi ; update mask for remaining columns
xor ebp,ebp ; no more columns remaining
Output16xNBlockWithMask:
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput16xNBlockWithMask
EmitIfCountGE RowCount, 1, <vpaddd zmm14{k1},zmm14,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm15{k1},zmm15,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm16{k1},zmm16,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput16xNBlockWithMask:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm14>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm15>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm16>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19>
add r8,16*4 ; advance matrix C by 16 columns
mov rcx,rdi ; reload matrix A
cmp rbp,32
ja ProcessNextColumnLoop48xN
cmp rbp,16
ja ProcessNextColumnLoop32xN
test rbp,rbp
jz ExitKernel
ProcessRemainingCountN:
ProduceOutputBlock 16, RowCount
jmp Output16xNBlock
ProcessNextColumnLoop48xN:
ProduceOutputBlock 48, RowCount
lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride
test r10b,r10b ; ZeroMode?
jnz SkipAccumulateOutput48xNBlock
EmitIfCountGE RowCount, 1, <vpaddd zmm26,zmm26,ZMMWORD PTR [r8]>
EmitIfCountGE RowCount, 2, <vpaddd zmm27,zmm27,ZMMWORD PTR [r8+rax]>
EmitIfCountGE RowCount, 3, <vpaddd zmm28,zmm28,ZMMWORD PTR [r8+rax*2]>
EmitIfCountGE RowCount, 4, <vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]>
EmitIfCountGE RowCount, 5, <vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]>
EmitIfCountGE RowCount, 6, <vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]>
SkipAccumulateOutput48xNBlock:
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm26>
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm27>
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm28>
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm29>
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30>
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31>
add r8,16*4 ; advance matrix C by 16 columns
IF RowCount GT 3
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
ENDIF
sub rbp,16
jmp Output32xNBlock
ENDM
;
; Macro Description:
;
; This macro generates the common AVX512 code for the inner kernel to compute
; matrix multiplication.
;
; Arguments:
;
; Type - Supplies the kernel type string for function tags.
;
; Isa - Supplies the instruction set architecture string for function tags.
;
GemmU8X8KernelAvx512Function MACRO Type, Isa
;++
;
; Routine Description:
;
; This routine is an inner kernel to compute matrix multiplication for a
; set of rows.
;
; Arguments:
;
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
; using MlasGemmU8X8CopyPackAAvx2.
;
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
; using MlasGemmU8X8CopyPackBAvx2.
;
; C (r8) - Supplies the address of matrix C.
;
; QuadCountK (r9) - Supplies the number of quad columns from matrix A and the
; number of quad rows from matrix B to iterate over.
;
; CountM - Supplies the maximum number of rows that can be processed for
; matrix A and matrix C. The actual number of rows handled for this
; invocation depends on the kernel implementation.
;
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
; over.
;
; ldc - Supplies the first dimension of matrix C.
;
; RowSumVector - Supplies the sum of each row from matrix A multiplied by the
; zero point offset of matrix B. These values are accumulated into every
; row of matrix C.
;
; ColumnSumVector - Supplies the sum of each column from matrix B multiplied
; by the zero point offset of matrix A. These values are accumulated into
; every column of matrix C.
;
; DepthValue - Supplies the value CountK multiplied by the zero point offset
; of matrixA multplied by the zero point offset of matrix B. This value is
; accumulated into every element of matrix C.
;
; ZeroMode - Supplies true if the output matrix must be zero initialized,
; else false if the output matrix is accumulated into.
;
; Return Value:
;
; Returns the number of rows handled.
;
;--
NESTED_ENTRY MlasGemm&Type&Kernel&Isa&, _TEXT
rex_push_reg rbp
push_reg rbx
push_reg rsi
push_reg rdi
push_reg r12
push_reg r13
push_reg r14
alloc_stack (GemmU8X8KernelFrame.SavedR14)
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
END_PROLOGUE
mov rdi,rcx
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
mov rax,GemmU8X8KernelFrame.ldc[rsp]
shl rax,2 ; convert ldc to bytes
shl r9,2 ; convert to row length
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
mov r11,GemmU8X8KernelFrame.CountM[rsp]
mov r12,GemmU8X8KernelFrame.RowSumVector[rsp]
mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp]
mov esi,-1
kmovw k1,esi ; update mask to write all columns
IFIDNI <Type>, <U8S8>
IFIDNI <Isa>, <Avx512BW>
neg esi
vpbroadcastw zmm5,esi ; generate 512-bit word vector [0x0001]
ENDIF
mov r14,r9
shl r14,4 ; compute matrix B packed stride
ELSE
lea r14,[r9*8] ; compute matrix B packed stride
ENDIF
;
; Process CountM rows of the matrices.
;
cmp r11,5
ja ProcessCountM6
je ProcessCountM5
cmp r11,3
ja ProcessCountM4
je ProcessCountM3
cmp r11,1
je ProcessCountM1
ProcessCountM2:
ProcessCountM 2
ProcessCountM4:
ProcessCountM 4
ProcessCountM6:
mov r11d,6 ; return 6 rows handled
ProcessCountM 6
;
; Restore non-volatile registers and return.
;
ExitKernel:
mov eax,r11d
vzeroupper
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
add rsp,(GemmU8X8KernelFrame.SavedR14)
BEGIN_EPILOGUE
pop r14
pop r13
pop r12
pop rdi
pop rsi
pop rbx
pop rbp
ret
ProcessCountM1:
ProcessCountM 1
ProcessCountM3:
ProcessCountM 3
ProcessCountM5:
ProcessCountM 5
NESTED_END MlasGemm&Type&Kernel&Isa&, _TEXT
ENDM

View file

@ -194,6 +194,52 @@ void
typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE;
typedef
void
(MLASCALL MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE)(
uint8_t* D,
const uint8_t* A,
size_t lda,
size_t CountM,
size_t CountK,
int32_t* RowSumVector,
int16_t offb
);
typedef MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE* PMLAS_GEMM_U8S8_COPY_PACKA_ROUTINE;
typedef
void
(MLASCALL MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE)(
int8_t* D,
const int8_t* B,
size_t ldb,
size_t CountN,
size_t CountK,
int32_t* ColumnSumVector,
int16_t offa
);
typedef MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE* PMLAS_GEMM_U8S8_COPY_PACKB_ROUTINE;
typedef
size_t
(MLASCALL MLAS_GEMM_U8S8_KERNEL)(
const uint8_t* A,
const int8_t* B,
int32_t* C,
size_t QuadCountK,
size_t CountM,
size_t CountN,
size_t ldc,
const int32_t* RowSumVector,
const int32_t* ColumnSumVector,
int32_t DepthValue,
bool ZeroMode
);
typedef MLAS_GEMM_U8S8_KERNEL* PMLAS_GEMM_U8S8_KERNEL;
typedef
void
(MLASCALL MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE)(
@ -228,7 +274,7 @@ size_t
const int16_t* A,
const uint8_t* B,
int32_t* C,
size_t PairedCountK,
size_t PairCountK,
size_t CountM,
size_t CountN,
size_t ldc,
@ -364,10 +410,18 @@ extern "C" {
#endif
#if defined(MLAS_TARGET_AMD64_IX86)
MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE MlasGemmU8S8CopyPackASse;
MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE MlasGemmU8S8CopyPackBSse;
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelSse;
MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackASse;
MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBSse;
MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelSse;
#if defined(MLAS_TARGET_AMD64)
MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE MlasGemmU8S8CopyPackAAvx2;
MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE MlasGemmU8S8CopyPackBAvx2;
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx2;
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512BW;
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512Vnni;
MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackAAvx2;
MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBAvx2;
MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx2;
@ -490,6 +544,9 @@ struct MLAS_PLATFORM {
#if defined(MLAS_TARGET_AMD64_IX86)
PMLAS_GEMM_FLOAT_KERNEL GemmFloatKernel;
PMLAS_GEMM_U8S8_COPY_PACKA_ROUTINE GemmU8S8CopyPackARoutine;
PMLAS_GEMM_U8S8_COPY_PACKB_ROUTINE GemmU8S8CopyPackBRoutine;
PMLAS_GEMM_U8S8_KERNEL GemmU8S8Kernel;
PMLAS_GEMM_U8U8_COPY_PACKA_ROUTINE GemmU8U8CopyPackARoutine;
PMLAS_GEMM_U8U8_COPY_PACKB_ROUTINE GemmU8U8CopyPackBRoutine;
PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel;

View file

@ -85,6 +85,9 @@ Return Value:
//
this->GemmFloatKernel = MlasGemmFloatKernelSse;
this->GemmU8S8CopyPackARoutine = MlasGemmU8S8CopyPackASse;
this->GemmU8S8CopyPackBRoutine = MlasGemmU8S8CopyPackBSse;
this->GemmU8S8Kernel = MlasGemmU8S8KernelSse;
this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackASse;
this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBSse;
this->GemmU8U8Kernel = MlasGemmU8U8KernelSse;
@ -157,6 +160,9 @@ Return Value:
if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) {
this->GemmU8S8CopyPackARoutine = MlasGemmU8S8CopyPackAAvx2;
this->GemmU8S8CopyPackBRoutine = MlasGemmU8S8CopyPackBAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2;
this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackAAvx2;
this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBAvx2;
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2;
@ -180,6 +186,7 @@ Return Value:
if ((Cpuid7[1] & 0x40000000) != 0) {
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512BW;
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512BW;
//
@ -187,6 +194,8 @@ Return Value:
//
if ((Cpuid7[2] & 0x800) != 0) {
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni;
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Vnni;
}
}

View file

@ -21,12 +21,577 @@ Abstract:
// Define the default strides to step through slices of the input matrices.
//
#define MLAS_GEMM_U8U8_STRIDEM 12
#define MLAS_GEMM_U8U8_STRIDEN 128
#define MLAS_GEMM_U8S8_STRIDEM 24
#define MLAS_GEMM_U8S8_STRIDEN 256
#define MLAS_GEMM_U8S8_STRIDEK 128
#define MLAS_GEMM_U8U8_STRIDEM 24
#define MLAS_GEMM_U8U8_STRIDEN 256
#define MLAS_GEMM_U8U8_STRIDEK 128
#ifdef MLAS_TARGET_AMD64_IX86
//
// U8S8 implementation using SSE2 intrinsics.
//
void
MLASCALL
MlasGemmU8S8CopyPackASse(
uint8_t* D,
const uint8_t* A,
size_t lda,
size_t CountM,
size_t CountK,
int32_t* RowSumVector,
int16_t offb
)
/*++
Routine Description:
This routine copies elements from the source matrix to the destination
packed buffer.
Arguments:
D - Supplies the address of the destination packed buffer.
A - Supplies the address of the source matrix.
lda - Supplies the number of elements per row of the source matrix.
CountM - Supplies the number of rows of the source matrix to copy.
CountK - Supplies the number of columns of the source matrix to copy.
RowSumVector - Supplies the address of the buffer to receive the sums of
the elements from each of the rows. Each sum has also been multiplied
by the zero point offset.
offb - Supplies the zero point offset for the other source matrix of the
matrix multiplication.
Return Value:
None.
--*/
{
const __m128i ZeroVector = _mm_setzero_si128();
const __m128i OffsetBroadcast = _mm_set1_epi16(offb);
uint8_t PaddedMatrixAData[8] = { 0 };
//
// Process a single row of matrix A in a loop.
//
while (CountM > 0) {
const uint8_t* a = A;
size_t k = CountK;
__m128i RowSum = ZeroVector;
//
// Copy the source bytes to the packed buffer.
//
// The packed buffer has the same data ordering as the source bytes,
// but CountK is aligned up to a multiple of 4 to maintain 32-bit
// alignment. All extra bytes are zero-padded.
//
// These values are also zero-extended and accumulated into an
// intermediate per-row accumulator. CountK cannot be greater than 128
// to avoid overflowing these signed 16-bit accumulators.
//
while (k >= 8) {
__m128i Bytes = _mm_loadl_epi64((__m128i*)&a[0]);
_mm_storel_epi64((__m128i*)&D[0], Bytes);
RowSum = _mm_add_epi16(RowSum, _mm_unpacklo_epi8(Bytes, ZeroVector));
D += 8;
a += 8;
k -= 8;
}
if (k > 0) {
//
// Copy the remaining bytes to the zero padded stack buffer.
//
uint8_t* padded = PaddedMatrixAData;
uint8_t* padded_end = padded + k;
do {
padded[0] = a[0];
padded++;
a++;
} while (padded < padded_end);
__m128i Bytes = _mm_loadl_epi64((__m128i*)PaddedMatrixAData);
_mm_storel_epi64((__m128i*)&D[0], Bytes);
RowSum = _mm_add_epi16(RowSum, _mm_unpacklo_epi8(Bytes, ZeroVector));
//
// Copy quads of 8-bit values from the vector to the packed
// buffer and rotate the vector for the next iteration.
//
for (size_t quads = (k + 3) / 4; quads > 0; quads--) {
*((int32_t*)D) = _mm_cvtsi128_si32(Bytes);
D += 4;
Bytes = _mm_shuffle_epi32(Bytes, _MM_SHUFFLE(0, 3, 2, 1));
}
}
//
// Reduce the sum for the single row of output and multiply by the
// zero point offset of the other source matrix.
//
RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast);
RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(3, 2, 3, 2)));
RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(0, 1, 0, 1)));
*RowSumVector++ = _mm_cvtsi128_si32(RowSum);
A += lda;
CountM -= 1;
}
}
void
MLASCALL
MlasGemmU8S8CopyPackBSse(
int8_t* D,
const int8_t* B,
size_t ldb,
size_t CountN,
size_t CountK,
int32_t* ColumnSumVector,
int16_t offa
)
/*++
Routine Description:
This routine copies elements from the source matrix to the destination
packed buffer.
Arguments:
D - Supplies the address of the destination packed buffer.
B - Supplies the address of the source matrix.
ldb - Supplies the number of elements per row of the source matrix.
CountN - Supplies the number of columns of the source matrix to copy.
CountK - Supplies the number of rows of the source matrix to copy.
ColumnSumVector - Supplies the address of the buffer to receive the sums of
the elements from each of the columns. Each sum has also been multiplied
by the zero point offset.
offa - Supplies the zero point offset for the other source matrix of the
matrix multiplication.
Return Value:
None.
--*/
{
const __m128i ZeroVector = _mm_setzero_si128();
const __m128i OffsetBroadcast = _mm_set1_epi16(offa);
int8_t PaddedMatrixBData[16] = { 0 };
//
// Process 8 columns of matrix B in a loop.
//
while (CountN >= 8) {
const int8_t* b = B;
size_t k = CountK;
__m128i ColumnSum0 = ZeroVector;
__m128i ColumnSum1 = ZeroVector;
//
// Interleave 2 rows of matrix B and write to the packed buffer.
//
// These values are also sign-extended and accumulated into an
// intermediate per-column accumulator. CountK cannot be greater than
// 128 to avoid overflowing these signed 16-bit accumulators.
//
while (k >= 2) {
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
__m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&b[ldb]);
__m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1);
_mm_storeu_si128((__m128i*)&D[0], BytesInterleaved);
__m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow);
__m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh);
b += ldb * 2;
D += 16;
k -= 2;
}
if (k > 0) {
//
// Process the remaining row of matrix B.
//
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
__m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector);
_mm_storeu_si128((__m128i*)&D[0], BytesInterleaved);
__m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow);
__m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh);
D += 16;
}
//
// The number of rows written to the packed buffer should be a multiple
// of 4. Zero pad the packed buffer if the block is not complete.
//
if (((CountK - 1) & 2) == 0) {
_mm_storeu_si128((__m128i*)&D[0], ZeroVector);
D += 16;
}
ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast);
ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast);
_mm_storeu_si128((__m128i*)&ColumnSumVector[0], ColumnSum0);
_mm_storeu_si128((__m128i*)&ColumnSumVector[4], ColumnSum1);
ColumnSumVector += 8;
B += 8;
CountN -= 8;
}
//
// Process the remaining columns of matrix B.
//
if (CountN > 0) {
const int8_t* b = B;
size_t k = CountK;
__m128i ColumnSum0 = ZeroVector;
__m128i ColumnSum1 = ZeroVector;
//
// Interleave 2 rows of matrix B and write to the packed buffer.
//
// These values are also sign-extended and accumulated into an
// intermediate per-column accumulator. CountK cannot be greater than
// 128 to avoid overflowing these signed 16-bit accumulators.
//
while (k >= 2) {
//
// Copy the remaining columns to the zero padded stack buffer.
//
const int8_t* bcopy = b;
int8_t* padded = PaddedMatrixBData;
int8_t* padded_end = padded + CountN;
do {
padded[0] = bcopy[0];
padded[8] = bcopy[ldb];
padded++;
bcopy++;
} while (padded < padded_end);
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]);
__m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[8]);
__m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1);
_mm_storeu_si128((__m128i*)&D[0], BytesInterleaved);
__m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow);
__m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh);
b += ldb * 2;
D += 16;
k -= 2;
}
if (k > 0) {
//
// Copy the remaining columns to the zero padded stack buffer.
//
const int8_t* bcopy = b;
int8_t* padded = PaddedMatrixBData;
int8_t* padded_end = padded + CountN;
do {
padded[0] = bcopy[0];
padded++;
bcopy++;
} while (padded < padded_end);
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]);
__m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector);
_mm_storeu_si128((__m128i*)&D[0], BytesInterleaved);
__m128i WordsLow = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum0 = _mm_add_epi16(ColumnSum0, WordsLow);
__m128i WordsHigh = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BytesInterleaved), 8);
ColumnSum1 = _mm_add_epi16(ColumnSum1, WordsHigh);
D += 16;
}
//
// The number of rows written to the packed buffer should be a multiple
// of 4. Zero pad the packed buffer if the block is not complete.
//
if (((CountK - 1) & 2) == 0) {
_mm_storeu_si128((__m128i*)&D[0], ZeroVector);
D += 16;
}
ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast);
ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast);
_mm_storeu_si128((__m128i*)&ColumnSumVector[0], ColumnSum0);
_mm_storeu_si128((__m128i*)&ColumnSumVector[4], ColumnSum1);
}
}
size_t
MLASCALL
MlasGemmU8S8KernelSse(
const uint8_t* A,
const int8_t* B,
int32_t* C,
size_t PairCountK,
size_t CountM,
size_t CountN,
size_t ldc,
const int32_t* RowSumVector,
const int32_t* ColumnSumVector,
int32_t DepthValue,
bool ZeroMode
)
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8S8CopyPackASse.
B - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8S8CopyPackBSse.
C - Supplies the address of matrix C.
PairCountK - Supplies the number of paired columns from matrix A and the
number of paired rows from matrix B to iterate over.
CountM - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN - Supplies the number of columns from matrix B and matrix C to iterate
over.
ldc - Supplies the first dimension of matrix C.
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrixA multplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
{
const __m128i ZeroVector = _mm_setzero_si128();
MLAS_UNREFERENCED_PARAMETER(CountM);
MLAS_UNREFERENCED_PARAMETER(ldc);
while (CountN > 0) {
//
// Initialize the accumulators with the sum of the global depth value
// constant, the column sums, and the row sums.
//
__m128i Accumulator0 = _mm_set1_epi32(DepthValue);
Accumulator0 = _mm_add_epi32(Accumulator0, _mm_set1_epi32(RowSumVector[0]));
__m128i Accumulator1 = Accumulator0;
Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&ColumnSumVector[0]));
Accumulator1 = _mm_add_epi32(Accumulator1, _mm_loadu_si128((__m128i*)&ColumnSumVector[4]));
ColumnSumVector += 8;
//
// Broadcast each pair of 16-bit values from the matrix A and multiply
// with the zero-extended pair of 16-bit values from matrix B, and add
// the 32-bit intermediate into the accumulator registers.
//
const uint8_t* a = A;
size_t k = PairCountK;
while (k > 0) {
__m128i AElements = _mm_unpacklo_epi8(_mm_cvtsi32_si128(*((int32_t*)a)), ZeroVector);
__m128i BElements;
__m128i Intermediate0;
__m128i Intermediate1;
BElements = _mm_loadu_si128((__m128i*)&B[0]);
Intermediate0 = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BElements), 8);
Intermediate1 = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BElements), 8);
__m128i AElements0 = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(0, 0, 0, 0));
Intermediate0 = _mm_madd_epi16(Intermediate0, AElements0);
Intermediate1 = _mm_madd_epi16(Intermediate1, AElements0);
Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0);
Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1);
BElements = _mm_loadu_si128((__m128i*)&B[16]);
Intermediate0 = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, BElements), 8);
Intermediate1 = _mm_srai_epi16(_mm_unpackhi_epi8(ZeroVector, BElements), 8);
__m128i AElements1 = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(1, 1, 1, 1));
Intermediate0 = _mm_madd_epi16(Intermediate0, AElements1);
Intermediate1 = _mm_madd_epi16(Intermediate1, AElements1);
Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0);
Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1);
a += 4;
B += 32;
k -= 1;
}
//
// Output the accumulator block after optionally accumulating the values
// from matrix C.
//
if (CountN >= 8) {
if (!ZeroMode) {
Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&C[0]));
Accumulator1 = _mm_add_epi32(Accumulator1, _mm_loadu_si128((__m128i*)&C[4]));
}
_mm_storeu_si128((__m128i*)&C[0], Accumulator0);
_mm_storeu_si128((__m128i*)&C[4], Accumulator1);
C += 8;
CountN -= 8;
} else {
//
// Output the remaining partial output block.
//
if ((CountN & 4) != 0) {
if (!ZeroMode) {
Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadu_si128((__m128i*)&C[0]));
}
_mm_storeu_si128((__m128i*)&C[0], Accumulator0);
C += 4;
Accumulator0 = Accumulator1;
}
if ((CountN & 2) != 0) {
if (!ZeroMode) {
Accumulator0 = _mm_add_epi32(Accumulator0, _mm_loadl_epi64((__m128i*)&C[0]));
}
_mm_storel_epi64((__m128i*)&C[0], Accumulator0);
C += 2;
Accumulator0 = _mm_shuffle_epi32(Accumulator0, _MM_SHUFFLE(1, 0, 3, 2));
}
if ((CountN & 1) != 0) {
int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulator0);
if (!ZeroMode) {
AccumulatorValue += C[0];
}
C[0] = AccumulatorValue;
}
CountN = 0;
}
}
return 1;
}
//
// U8U8 implementation using SSE2 intrinsics.
//
void
MLASCALL
MlasGemmU8U8CopyPackASse(
@ -86,13 +651,15 @@ Return Value:
//
// Zero extend the source bytes to 16-bits and write to the packed
// buffer. The packed buffer has the same data ordering as the source
// bytes, but the stride is CountK aligned up to a multiple of 8
// values.
// buffer.
//
// The packed buffer has the same data ordering as the source bytes,
// but CountK is aligned up to a multiple of 2 to maintain 32-bit
// alignment. All extra bytes are zero-padded.
//
// These 16-bit values are also accumulated into an intermediate per-row
// accumulator. CountK cannot be greater than 256 to avoid overflowing
// these 16-bit accumulators.
// accumulator. CountK cannot be greater than 128 to avoid overflowing
// these signed 16-bit accumulators.
//
while (k >= 8) {
@ -130,8 +697,8 @@ Return Value:
RowSum = _mm_add_epi16(RowSum, Words);
//
// Copy the 16-bit pairs from the vector to the destination packed
// buffer. Rotate the vector at each iteration.
// Copy pairs of 16-bit values from the vector to the packed
// buffer and rotate the vector for the next iteration.
//
for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) {
@ -142,7 +709,8 @@ Return Value:
}
//
// Reduce the sum for the single row of output.
// Reduce the sum for the single row of output and multiply by the
// zero point offset of the other source matrix.
//
RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast);
@ -176,13 +744,13 @@ Routine Description:
Arguments:
D (rcx) - Supplies the address of the destination packed buffer.
D - Supplies the address of the destination packed buffer.
B (rdx) - Supplies the address of the source matrix.
B - Supplies the address of the source matrix.
ldb (r8) - Supplies the number of elements per row of the source matrix.
ldb - Supplies the number of elements per row of the source matrix.
CountN (r9) - Supplies the number of columns of the source matrix to copy.
CountN - Supplies the number of columns of the source matrix to copy.
CountK - Supplies the number of rows of the source matrix to copy.
@ -214,6 +782,14 @@ Return Value:
__m128i ColumnSum0 = ZeroVector;
__m128i ColumnSum1 = ZeroVector;
//
// Interleave 2 rows of matrix B and write to the packed buffer.
//
// These values are also zero-extended and accumulated into an
// intermediate per-column accumulator. CountK cannot be greater than
// 128 to avoid overflowing these signed 16-bit accumulators.
//
while (k >= 2) {
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
@ -232,6 +808,10 @@ Return Value:
if (k > 0) {
//
// Process the remaining row of matrix B.
//
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
__m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector);
@ -240,7 +820,6 @@ Return Value:
ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector));
ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector));
b += ldb * 2;
D += 16;
}
@ -323,6 +902,11 @@ Return Value:
ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector));
}
//
// Reduce the sum for the packed columns and multiply by the zero point
// offset of the other source matrix.
//
ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast);
ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast);
@ -337,7 +921,7 @@ MlasGemmU8U8KernelSse(
const int16_t* A,
const uint8_t* B,
int32_t* C,
size_t PairedCountK,
size_t PairCountK,
size_t CountM,
size_t CountN,
size_t ldc,
@ -363,8 +947,8 @@ Arguments:
C - Supplies the address of matrix C.
PairedCountK - Supplies the number of paired columns from matrix A and
the number of paired rows from matrix B to iterate over.
PairCountK - Supplies the number of paired columns from matrix A and the
number of paired rows from matrix B to iterate over.
CountM - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
@ -422,18 +1006,18 @@ Return Value:
//
const int16_t* a = A;
size_t k = PairedCountK;
size_t k = PairCountK;
while (k > 0) {
__m128i AElements0 = _mm_set1_epi32(*((int32_t*)a));
__m128i AElements = _mm_set1_epi32(*((int32_t*)a));
__m128i BElements0 = _mm_loadu_si128((__m128i*)&B[0]);
__m128i Intermediate0 = _mm_unpacklo_epi8(BElements0, ZeroVector);
__m128i Intermediate1 = _mm_unpackhi_epi8(BElements0, ZeroVector);
Intermediate0 = _mm_madd_epi16(Intermediate0, AElements0);
Intermediate1 = _mm_madd_epi16(Intermediate1, AElements0);
Intermediate0 = _mm_madd_epi16(Intermediate0, AElements);
Intermediate1 = _mm_madd_epi16(Intermediate1, AElements);
Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0);
Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1);
@ -502,7 +1086,7 @@ Return Value:
C[0] = AccumulatorValue;
}
break;
CountN = 0;
}
}
@ -511,7 +1095,94 @@ Return Value:
void
MLASCALL
MlasQgemm(
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const int8_t* B,
size_t ldb,
int8_t offb,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
)
{
MLAS_DECLSPEC_ALIGN(uint8_t PanelA[MLAS_GEMM_U8S8_STRIDEM * MLAS_GEMM_U8S8_STRIDEK], 64);
MLAS_DECLSPEC_ALIGN(int8_t PanelB[MLAS_GEMM_U8S8_STRIDEN * MLAS_GEMM_U8S8_STRIDEK], 64);
MLAS_DECLSPEC_ALIGN(int32_t RowSumVector[MLAS_GEMM_U8S8_STRIDEM], 16);
MLAS_DECLSPEC_ALIGN(int32_t ColumnSumVector[MLAS_GEMM_U8S8_STRIDEN], 16);
size_t StrideM = MLAS_GEMM_U8S8_STRIDEM;
size_t StrideN = MLAS_GEMM_U8S8_STRIDEN;
size_t StrideK = MLAS_GEMM_U8S8_STRIDEK;
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
size_t CountK;
for (size_t k = 0; k < K; k += CountK) {
CountK = StrideK;
if (CountK > (K - k)) {
CountK = K - k;
}
size_t CountN;
for (size_t n = 0; n < N; n += CountN) {
CountN = StrideN;
if (CountN > (N - n)) {
CountN = N - n;
}
MlasPlatform.GemmU8S8CopyPackBRoutine(PanelB, B + n + k * ldb, ldb, CountN, CountK, ColumnSumVector, -int16_t(offa));
size_t CountM;
for (size_t m = 0; m < M; m += CountM) {
CountM = StrideM;
if (CountM > (M - m)) {
CountM = M - m;
}
MlasPlatform.GemmU8S8CopyPackARoutine(PanelA, A + k + m * lda, lda, CountM, CountK, RowSumVector, -int16_t(offb));
uint8_t* pa = PanelA;
int32_t* c = C + n + m * ldc;
int32_t* RowSums = RowSumVector;
size_t RowsRemaining = CountM;
size_t RowsHandled;
size_t QuadCountK = (CountK + 3) / 4;
while (RowsRemaining > 0) {
RowsHandled = MlasPlatform.GemmU8S8Kernel(pa, PanelB, c, QuadCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0);
RowsRemaining -= RowsHandled;
c += ldc * RowsHandled;
pa += 4 * QuadCountK * RowsHandled;
RowSums += RowsHandled;
}
}
}
}
}
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
@ -580,15 +1251,15 @@ MlasQgemm(
size_t RowsRemaining = CountM;
size_t RowsHandled;
size_t PairedCountK = (CountK + 1) / 2;
size_t PairCountK = (CountK + 1) / 2;
while (RowsRemaining > 0) {
RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairedCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0);
RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0);
RowsRemaining -= RowsHandled;
c += ldc * RowsHandled;
pa += 2 * PairedCountK * RowsHandled;
pa += 2 * PairCountK * RowsHandled;
RowSums += RowsHandled;
}
}

View file

@ -140,7 +140,7 @@ Macro Description:
This macro builds a VNNI instruction of the form:
instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale]{1to16}
instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale+ByteOffset]{1to16}
Arguments:
@ -152,13 +152,16 @@ Arguments:
BaseReg - Specifies the base register of the broadcast operand.
ByteOffset - Specifies the DWORD aligned byte offset for the broadcast
operand.
IndexReg - Specifies the optional index register of the broadcast operand.
Scale - Specifies the scaling factor of the optional index register.
--*/
.macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale
.macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
.set Payload0, 0x02 # "0F 38" prefix
.set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7)
@ -183,6 +186,9 @@ Arguments:
.else
.set ModRMByte, ModRMByte + (.LGprIndex_\BaseReg\() & 7)
.endif
.if \ByteOffset\() != 0
.set ModRMByte, ModRMByte + 0x40 # indicate disp8 byte offset
.endif
.ifnes "\IndexReg\()", ""
.set SibByte, 0
@ -205,34 +211,36 @@ Arguments:
.set SibByte, SibByte + (.LGprIndex_\BaseReg\() & 7)
.endif
.ifnes "\IndexReg\()", ""
.byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte, SibByte
.else
.byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte
.ifnes "\IndexReg\()", ""
.byte SibByte
.endif
.if \ByteOffset\() != 0
.byte (\ByteOffset\() >> 2)
.endif
.endm
.macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
.macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
.endm
.macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
.macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
.endm
.macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
.macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
.endm
.macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
.macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
.endm

View file

@ -0,0 +1,955 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx2.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX2 instructions.
--*/
#include "asmmacro.h"
#include "QgemmU8X8KernelAvx2Common.h"
.intel_syntax noprefix
//
// Stack frame layout for the U8S8 CopyPackA routine.
//
.equ .LGemmU8S8CopyPackAFrame_PaddedMatrixAData, -72
.equ .LGemmU8S8CopyPackAFrame_mask, -8
.equ .LGemmU8S8CopyPackAFrame_SavedR13, 0
.equ .LGemmU8S8CopyPackAFrame_SavedR12, 8
.equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16
.equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24
.equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32
.equ .LGemmU8S8CopyPackAFrame_offb, 40
//
// Stack frame layout for the U8S8 CopyPackB routine.
//
.equ .LGemmU8S8CopyPackBFrame_PaddedMatrixBData, -72
.equ .LGemmU8S8CopyPackBFrame_Padding, -8
.equ .LGemmU8S8CopyPackBFrame_SavedRbx, 0
.equ .LGemmU8S8CopyPackBFrame_SavedRbp, 8
.equ .LGemmU8S8CopyPackBFrame_ReturnAddress, 16
.equ .LGemmU8S8CopyPackBFrame_offa, 24
.text
/*++
Routine Description:
This routine copies elements from the source matrix to the destination
packed buffer.
Arguments:
D (rdi) - Supplies the address of the destination packed buffer.
A (rsi) - Supplies the address of the source matrix.
lda (rdx) - Supplies the number of elements per row of the source matrix.
CountM (rcx) - Supplies the number of rows of the source matrix to copy.
CountK (r8) - Supplies the number of columns of the source matrix to copy.
RowSumVector (r9) - Supplies the address of the buffer to receive the sums
of the elements from each of the rows. Each sum has also been multiplied
by the zero point offset.
offb - Supplies the zero point offset for the other source matrix of the
matrix multiplication.
Return Value:
None.
--*/
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2)
C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2):
push rbp
push rbx
push r12
push r13
mov r10,rdx
mov r11,rcx
lea r12,[r8+3]
and r12,NOT 3 # align CountK up to quad count
vpbroadcastw xmm8,WORD PTR .LGemmU8S8CopyPackAFrame_offb[rsp]
vpcmpeqw ymm9,ymm9,ymm9 # generate word vector [0xFFFF]
vpsrlw ymm9,ymm9,15 # generate word vector [0x0001]
vpsllw ymm0,ymm9,8 # generate word vector [0x0100]
vpor ymm9,ymm9,ymm0 # generate word vector [0x0101]
//
// Compute the conditional load/store mask for an unaligned CountK.
//
mov eax,r8d
and eax,15 # isolate unaligned count
add eax,3
shr eax,2 # align unaligned count to quad count
mov DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp],eax
vpbroadcastd xmm10,DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp]
vpcmpgtd xmm10,xmm10,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
//
// Zero initialize the padded stack buffers.
//
vpxor xmm0,xmm0,xmm0
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0
//
// Process 4 rows of matrix A in a loop.
//
sub r11,4
jb .LCopyPackA.ProcessRemainingRows
.LCopyPackA.ProcessNextRowM4:
vpxor xmm0,xmm0,xmm0 # clear row accumulators
vpxor xmm1,xmm1,xmm1
vpxor xmm2,xmm2,xmm2
vpxor xmm3,xmm3,xmm3
mov rdx,rsi
mov rcx,rdi
lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows
lea rdi,[rdi+r12*4] # advance next matrix D by 4 rows
mov rbx,r8 # reload columns remaining
sub rbx,32
jb .LCopyPackA.ProcessRemainingColumnsM4
.LCopyPackA.ProcessNextColumnLoopM4:
lea rax,[rdx+r10*2] # compute matrix A plus 2 rows
vmovdqu ymm4,YMMWORD PTR [rdx]
vmovdqu ymm5,YMMWORD PTR [rdx+r10]
vmovdqu ymm6,YMMWORD PTR [rax]
vmovdqu ymm7,YMMWORD PTR [rax+r10]
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
vmovdqu YMMWORD PTR [rcx],ymm4
vmovdqu YMMWORD PTR [rcx+r12],ymm5
vmovdqu YMMWORD PTR [rax],ymm6
vmovdqu YMMWORD PTR [rax+r12],ymm7
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
vpmaddubsw ymm5,ymm5,ymm9
vpaddw ymm1,ymm1,ymm5
vpmaddubsw ymm6,ymm6,ymm9
vpaddw ymm2,ymm2,ymm6
vpmaddubsw ymm7,ymm7,ymm9
vpaddw ymm3,ymm3,ymm7
add rdx,32 # advance matrix A by 32 bytes
add rcx,32 # advance matrix D by 32 bytes
sub rbx,32 # subtract columns remaining
jae .LCopyPackA.ProcessNextColumnLoopM4
.LCopyPackA.ProcessRemainingColumnsM4:
add rbx,32 # correct for over-subtract above
jz .LCopyPackA.ReduceRowSumVectorM4
test bl,16 # (CountK & 16) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan16M4
lea rax,[rdx+r10*2] # compute matrix A plus 2 rows
vmovdqu xmm4,XMMWORD PTR [rdx]
vmovdqu xmm5,XMMWORD PTR [rdx+r10]
vmovdqu xmm6,XMMWORD PTR [rax]
vmovdqu xmm7,XMMWORD PTR [rax+r10]
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
vmovdqu XMMWORD PTR [rcx],xmm4
vmovdqu XMMWORD PTR [rcx+r12],xmm5
vmovdqu XMMWORD PTR [rax],xmm6
vmovdqu XMMWORD PTR [rax+r12],xmm7
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
vpmaddubsw xmm5,xmm5,xmm9
vpaddw ymm1,ymm1,ymm5
vpmaddubsw xmm6,xmm6,xmm9
vpaddw ymm2,ymm2,ymm6
vpmaddubsw xmm7,xmm7,xmm9
vpaddw ymm3,ymm3,ymm7
add rdx,16 # advance matrix A by 16 bytes
add rcx,16 # advance matrix D by 16 bytes
test bl,15 # test for unaligned columns
jz .LCopyPackA.ReduceRowSumVectorM4
//
// Copy the unaligned CountK columns to a zero padded stack buffer.
//
.LCopyPackA.CopyRemainingCountKLessThan16M4:
lea rbp,.LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
test bl,8 # (CountK & 8) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan8M4
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
mov rax,QWORD PTR [rdx+r10]
mov QWORD PTR [rbp+16],rax
mov rax,QWORD PTR [r13]
mov QWORD PTR [rbp+32],rax
mov rax,QWORD PTR [r13+r10]
mov QWORD PTR [rbp+48],rax
add rdx,8
add rbp,8 # advance padded buffer destination
.LCopyPackA.CopyRemainingCountKLessThan8M4:
test bl,4 # (CountK & 4) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan4M4
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
mov eax,DWORD PTR [rdx+r10]
mov DWORD PTR [rbp+16],eax
mov eax,DWORD PTR [r13]
mov DWORD PTR [rbp+32],eax
mov eax,DWORD PTR [r13+r10]
mov DWORD PTR [rbp+48],eax
add rdx,4
add rbp,4 # advance padded buffer destination
.LCopyPackA.CopyRemainingCountKLessThan4M4:
test bl,2 # (CountK & 2) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan2M4
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
movzx eax,WORD PTR [rdx+r10]
mov WORD PTR [rbp+16],ax
movzx eax,WORD PTR [r13]
mov WORD PTR [rbp+32],ax
movzx eax,WORD PTR [r13+r10]
mov WORD PTR [rbp+48],ax
add rdx,2
add rbp,2 # advance padded buffer destination
.LCopyPackA.CopyRemainingCountKLessThan2M4:
test bl,1 # (CountK & 1) != 0?
jz .LCopyPackA.ProcessPaddedMatrixADataM4
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
movzx eax,BYTE PTR [rdx+r10]
mov BYTE PTR [rbp+16],al
movzx eax,BYTE PTR [r13]
mov BYTE PTR [rbp+32],al
movzx eax,BYTE PTR [r13+r10]
mov BYTE PTR [rbp+48],al
//
// Process the remaining CountK columns using the zero padded stack buffer.
//
.LCopyPackA.ProcessPaddedMatrixADataM4:
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
vmovdqu xmm5,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+16]
vmovdqu xmm6,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32]
vmovdqu xmm7,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+48]
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5
vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6
vpmaskmovd XMMWORD PTR [rax+r12],xmm10,xmm7
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
vpmaddubsw xmm5,xmm5,xmm9
vpaddw ymm1,ymm1,ymm5
vpmaddubsw xmm6,xmm6,xmm9
vpaddw ymm2,ymm2,ymm6
vpmaddubsw xmm7,xmm7,xmm9
vpaddw ymm3,ymm3,ymm7
//
// Reduce the sums for the four rows of output.
//
.LCopyPackA.ReduceRowSumVectorM4:
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0
vphaddw ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0
vextracti128 xmm1,ymm0,1 # extract high pairs
vpaddw xmm0,xmm0,xmm1 # reduce low/high pairs
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce 32-bit sum
vmovdqu XMMWORD PTR [r9],xmm0
add r9,4*4 # advance row sum vector by 4 DWORDs
sub r11,4 # subtract rows remaining
jae .LCopyPackA.ProcessNextRowM4
.LCopyPackA.ProcessRemainingRows:
add r11,4 # correct for over-subtract above
jz .LCopyPackA.ExitRoutine
//
// Process a single row of matrix A in a loop.
//
.LCopyPackA.ProcessNextRowM1:
vpxor xmm0,xmm0,xmm0 # clear row accumulator
mov rdx,rsi
mov rcx,rdi
add rsi,r10
add rdi,r12
mov rbx,r8 # reload columns remaining
sub rbx,32
jb .LCopyPackA.ProcessRemainingColumnsM1
.LCopyPackA.ProcessNextColumnLoopM1:
vmovdqu ymm4,YMMWORD PTR [rdx]
vmovdqu YMMWORD PTR [rcx],ymm4
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
add rdx,32 # advance matrix A by 32 bytes
add rcx,32 # advance matrix D by 32 bytes
sub rbx,32 # subtract columns remaining
jae .LCopyPackA.ProcessNextColumnLoopM1
.LCopyPackA.ProcessRemainingColumnsM1:
add rbx,32 # correct for over-subtract above
jz .LCopyPackA.ReduceRowSumVectorM1
test bl,16 # (CountK & 16) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan16M1
vmovdqu xmm4,XMMWORD PTR [rdx]
vmovdqu XMMWORD PTR [rcx],xmm4
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
add rdx,16 # advance matrix A by 16 bytes
add rcx,16 # advance matrix D by 16 bytes
test bl,15 # test for unaligned columns
jz .LCopyPackA.ReduceRowSumVectorM1
//
// Copy the unaligned CountK columns to a zero padded stack buffer.
//
.LCopyPackA.CopyRemainingCountKLessThan16M1:
lea rbp,.LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
test bl,8 # (CountK & 8) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan8M1
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
add rdx,8
add rbp,8 # advance padded buffer destination
.LCopyPackA.CopyRemainingCountKLessThan8M1:
test bl,4 # (CountK & 4) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan4M1
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
add rdx,4
add rbp,4 # advance padded buffer destination
.LCopyPackA.CopyRemainingCountKLessThan4M1:
test bl,2 # (CountK & 2) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan2M1
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
add rdx,2
add rbp,2 # advance padded buffer destination
.LCopyPackA.CopyRemainingCountKLessThan2M1:
test bl,1 # (CountK & 1) != 0?
jz .LCopyPackA.ProcessPaddedMatrixADataM1
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
//
// Process the remaining CountK columns using the zero padded stack buffer.
//
.LCopyPackA.ProcessPaddedMatrixADataM1:
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns
//
// Reduce the sum for the single row of output.
//
.LCopyPackA.ReduceRowSumVectorM1:
vextracti128 xmm1,ymm0,1 # extract high pairs
vpaddw xmm0,xmm0,xmm1 # reduction
vphaddw xmm0,xmm0,xmm0
vphaddw xmm0,xmm0,xmm0
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce
vmovd DWORD PTR [r9],xmm0
add r9,4 # advance row sum vector by 1 DWORD
dec r11 # decrement rows remaining
jnz .LCopyPackA.ProcessNextRowM1
//
// Restore non-volatile registers and return.
//
.LCopyPackA.ExitRoutine:
vzeroupper
pop r13
pop r12
pop rbx
pop rbp
ret
/*++
Routine Description:
This routine copies elements from the source matrix to the destination
packed buffer.
Arguments:
D (rdi) - Supplies the address of the destination packed buffer.
B (rsi) - Supplies the address of the source matrix.
ldb (rdx) - Supplies the number of elements per row of the source matrix.
CountN (rcx) - Supplies the number of columns of the source matrix to copy.
CountK (r8) - Supplies the number of rows of the source matrix to copy.
ColumnSumVector (r9) - Supplies the address of the buffer to receive the sums
of the elements from each of the columns. Each sum has also been
multiplied by the zero point offset.
offa - Supplies the zero point offset for the other source matrix of the
matrix multiplication.
Return Value:
None.
--*/
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2)
C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2):
push rbp
push rbx
mov r10,rdx
vpbroadcastw ymm7,WORD PTR .LGemmU8S8CopyPackBFrame_offa[rsp]
vpcmpeqw ymm8,ymm8,ymm8 # generate word vector [0xFFFF]
vpsrlw ymm8,ymm8,15 # generate word vector [0x0001]
vpsllw ymm0,ymm8,8 # generate word vector [0x0100]
vpor ymm8,ymm8,ymm0 # generate word vector [0x0101]
//
// Process 16 columns of matrix B in a loop.
//
sub rcx,16
jb .LCopyPackB.ProcessRemainingColumns
.LCopyPackB.ProcessNextColumnN16:
vpxor xmm0,xmm0,xmm0 # clear column accumulators
vpxor xmm1,xmm1,xmm1
mov rdx,rsi
add rsi,16 # advance next matrix B by 16 columns
mov rbx,r8 # reload rows remaining
sub rbx,4
jb .LCopyPackB.ProcessRemainingRowsN16
.LCopyPackB.ProcessNextRowLoopN16:
lea rax,[rdx+r10*2] # compute matrix B plus 2 rows
vmovdqu xmm2,XMMWORD PTR [rdx] # load 4 rows
vmovdqu xmm3,XMMWORD PTR [rdx+r10]
vmovdqu xmm4,XMMWORD PTR [rax]
vmovdqu xmm5,XMMWORD PTR [rax+r10]
lea rdx,[rdx+r10*4] # advance matrix B by 4 rows
.LCopyPackB.InterleaveRowDataN16:
vpunpcklbw xmm6,xmm2,xmm3 # interleave row data
vpunpckhbw xmm3,xmm2,xmm3
vpunpcklbw xmm2,xmm4,xmm5
vpunpckhbw xmm5,xmm4,xmm5
vpunpcklwd xmm4,xmm6,xmm2
vpunpckhwd xmm6,xmm6,xmm2
vpunpcklwd xmm2,xmm3,xmm5
vpunpckhwd xmm3,xmm3,xmm5
vinsertf128 ymm4,ymm4,xmm6,1
vinsertf128 ymm2,ymm2,xmm3,1
vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows
vmovdqu YMMWORD PTR [rdi+32],ymm2
vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
vpmaddubsw ymm2,ymm8,ymm2
vpaddw ymm1,ymm1,ymm2
add rdi,64 # advance matrix D by 64 bytes
sub rbx,4 # subtract rows remaining
jae .LCopyPackB.ProcessNextRowLoopN16
//
// Process the less than 4 remaining rows where the row has 16 columns.
//
.LCopyPackB.ProcessRemainingRowsN16:
add rbx,4 # correct for over-subtract above
jz .LCopyPackB.ReduceColumnSumVectorN16
vmovdqu xmm2,XMMWORD PTR [rdx]
vpxor xmm3,xmm3,xmm3
vpxor xmm4,xmm4,xmm4
vpxor xmm5,xmm5,xmm5
xor ebx,ebx # no more rows remaining
test r8b,2 # (CountK & 2) != 0?
jz .LCopyPackB.InterleaveRowDataN16
vmovdqu xmm3,XMMWORD PTR [rdx+r10]
test r8b,1 # (CountK & 1) != 0?
jz .LCopyPackB.InterleaveRowDataN16
vmovdqu xmm4,XMMWORD PTR [rdx+r10*2]
jmp .LCopyPackB.InterleaveRowDataN16
.LCopyPackB.ReduceColumnSumVectorN16:
vpmaddwd ymm0,ymm0,ymm7 # multiply by offset and reduce
vpmaddwd ymm1,ymm1,ymm7 # multiply by offset and reduce
vmovdqu YMMWORD PTR [r9],ymm0
vmovdqu YMMWORD PTR [r9+32],ymm1
add r9,16*4 # advance column sum vector by 16 DWORDs
sub rcx,16 # subtract columns remaining
jae .LCopyPackB.ProcessNextColumnN16
.LCopyPackB.ProcessRemainingColumns:
add rcx,16 # correct for over-subtract above
jnz .LCopyPackB.ProcessColumnNUnaligned
//
// Restore non-volatile registers and return.
//
.LCopyPackB.ExitRoutine:
vzeroupper
pop rbx
pop rbp
ret
//
// Process the remaining columns of matrix B.
//
.LCopyPackB.ProcessColumnNUnaligned:
vpxor xmm0,xmm0,xmm0 # clear column accumulators
vpxor xmm1,xmm1,xmm1
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp],ymm0
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+32],ymm0
sub r8,4
jb .LCopyPackB.ProcessRemainingRowsNUnaligned
.LCopyPackB.ProcessNextRowLoopNUnaligned:
mov rdx,rsi
lea rbp,.LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp]
test cl,8 # (CountN & 8) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan8K4
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
mov rax,QWORD PTR [rdx+r10]
mov QWORD PTR [rbp+16],rax
mov rax,QWORD PTR [r11]
mov QWORD PTR [rbp+32],rax
mov rax,QWORD PTR [r11+r10]
mov QWORD PTR [rbp+48],rax
add rdx,8 # advance matrix B
add rbp,8 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan8K4:
test cl,4 # (CountN & 4) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan4K4
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
mov eax,DWORD PTR [rdx+r10]
mov DWORD PTR [rbp+16],eax
mov eax,DWORD PTR [r11]
mov DWORD PTR [rbp+32],eax
mov eax,DWORD PTR [r11+r10]
mov DWORD PTR [rbp+48],eax
add rdx,4 # advance matrix B
add rbp,4 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan4K4:
test cl,2 # (CountN & 2) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan2K4
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
movzx eax,WORD PTR [rdx+r10]
mov WORD PTR [rbp+16],ax
movzx eax,WORD PTR [r11]
mov WORD PTR [rbp+32],ax
movzx eax,WORD PTR [r11+r10]
mov WORD PTR [rbp+48],ax
add rdx,2 # advance matrix B
add rbp,2 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan2K4:
test cl,1 # (CountN & 1) != 0?
jz .LCopyPackB.ProcessPaddedMatrixBData
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
movzx eax,BYTE PTR [rdx+r10]
mov BYTE PTR [rbp+16],al
movzx eax,BYTE PTR [r11]
mov BYTE PTR [rbp+32],al
movzx eax,BYTE PTR [r11+r10]
mov BYTE PTR [rbp+48],al
.LCopyPackB.ProcessPaddedMatrixBData:
vmovdqu xmm2,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp]
vmovdqu xmm3,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+16]
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+32]
vmovdqu xmm5,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+48]
vpunpcklbw xmm6,xmm2,xmm3 # interleave row data
vpunpckhbw xmm3,xmm2,xmm3
vpunpcklbw xmm2,xmm4,xmm5
vpunpckhbw xmm5,xmm4,xmm5
vpunpcklwd xmm4,xmm6,xmm2
vpunpckhwd xmm6,xmm6,xmm2
vpunpcklwd xmm2,xmm3,xmm5
vpunpckhwd xmm3,xmm3,xmm5
vinsertf128 ymm4,ymm4,xmm6,1
vinsertf128 ymm2,ymm2,xmm3,1
vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows
vmovdqu YMMWORD PTR [rdi+32],ymm2
vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
vpmaddubsw ymm2,ymm8,ymm2
vpaddw ymm1,ymm1,ymm2
lea rsi,[rsi+r10*4] # advance next matrix B by 4 rows
add rdi,64 # advance matrix D by 64 bytes
sub r8,4 # subtract rows remaining
jae .LCopyPackB.ProcessNextRowLoopNUnaligned
.LCopyPackB.ProcessRemainingRowsNUnaligned:
add r8,4
jz .LCopyPackB.ReduceColumnSumVectorNUnaligned
//
// Process the less than 4 remaining rows where the row has less than 16 columns.
//
lea rbp,.LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp]
vpxor xmm6,xmm6,xmm6
vmovdqu YMMWORD PTR [rbp],ymm6
vmovdqu YMMWORD PTR [rbp+32],ymm6
.LCopyPackB.CopyUnalignedRowLoop:
lea r11,[rbp+16] # advance next padded buffer by 16 bytes
mov rdx,rsi
test cl,8 # (CountN & 8) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan8KSmall
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
add rdx,8 # advance matrix B
add rbp,8 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan8KSmall:
test cl,4 # (CountN & 4) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan4KSmall
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
add rdx,4 # advance matrix B
add rbp,4 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan4KSmall:
test cl,2 # (CountN & 2) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan2KSmall
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
add rdx,2 # advance matrix B
add rbp,2 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan2KSmall:
test cl,1 # (CountN & 1) != 0?
jz .LCopyPackB.DoneCopyRemainingCountNKSmall
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
.LCopyPackB.DoneCopyRemainingCountNKSmall:
dec r8
jz .LCopyPackB.ProcessPaddedMatrixBData
add rsi,r10 # advance next matrix B by 1 row
mov rbp,r11
jmp .LCopyPackB.CopyUnalignedRowLoop
.LCopyPackB.ReduceColumnSumVectorNUnaligned:
vpmaddwd ymm0,ymm0,ymm7 # multiply by offset and reduce
vpmaddwd ymm1,ymm1,ymm7 # multiply by offset and reduce
vmovdqu YMMWORD PTR [r9],ymm0
vmovdqu YMMWORD PTR [r9+32],ymm1
jmp .LCopyPackB.ExitRoutine
/*++
Macro Description:
This macro generates code to multiply and accumulator a single row of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 16).
Vec2Reg - Supplies the low block accumulator register.
Implicit Arguments:
ymm0 - Supplies the first vector loaded from matrix B.
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
is 16).
ymm2 - Supplies the broadcast value loaded from matrix A.
ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001.
--*/
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
vpmaddubsw ymm3,ymm2,ymm0
vpmaddwd ymm3,ymm3,ymm12
.if \ColumnCount\() == 16
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
vpmaddubsw ymm2,ymm2,ymm1
vpmaddwd ymm2,ymm2,ymm12
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
.else
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
.endif
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rcx - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm11 - Supplies the block accumulators.
ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
.if \RowCount\() == 1
vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()]
vpmaddwd ymm3,ymm3,ymm12
.if \ColumnCount\() == 16
vpaddd ymm4,ymm4,ymm3
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32]
vpmaddwd ymm2,ymm2,ymm12
vpaddd ymm5,ymm5,ymm2
.else
vpaddd ymm5,ymm5,ymm3
.endif
.else
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
.endif
.endm
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm11 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
mov rbp,rcx # reload row length remaining
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8S8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8S8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
QuadCountK (rcx) - Supplies the number of quad columns from matrix A and
the number of quad rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrix A multplied by the zero point offset of matrix B. This value
is accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemmU8S8KernelAvx2)
C_UNDERSCORE(MlasGemmU8S8KernelAvx2):
push rbp
push rbx
push r12
push r13
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp]
vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF]
vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001]
//
// Process CountM rows of the matrices.
//
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
mov r8d,4 # return 4 rows handled
ProcessCountM 4, Fallthrough
//
// Restore non-volatile registers and return.
//
.LExitKernel:
mov eax,r8d
vzeroupper
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.end

View file

@ -0,0 +1,136 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx512BW.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX512BW instructions.
--*/
#include "asmmacro.h"
#include "QgemmU8S8KernelAvx512Common.h"
.intel_syntax noprefix
.text
/*++
Macro Description:
This macro generates code to multiply and accumulator a single cell of the
output block.
Arguments:
AccumReg - Supplies the register to accumulate into.
Mult1Reg - Supplies the first multiplication operand register.
Mult2Reg - Supplies the second multiplication operand register.
Implicit Arguments:
zmm4 - Supplies a scratch register for intermediate results.
zmm5 - Supplies a 512-bit with the broadcasted word value 0x0001.
--*/
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\()
vpmaddwd zmm4,zmm4,zmm5
vpaddd \AccumReg\(),\AccumReg\(),zmm4
.endm
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
.if \ColumnCount\() >= 48
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.elseif \ColumnCount\() >= 32
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
.else
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()]
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
.endm
//
// Generate the GEMM kernel.
//
GemmU8X8KernelAvx512Function U8S8, Avx512BW
.end

View file

@ -0,0 +1,88 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx512Common.h
Abstract:
This module contains common kernel macros and structures for the quantized
integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
AVX512VNNI kernels.
--*/
#include "QgemmU8X8KernelAvx512Common.h"
/*++
Macro Description:
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlockLoop ColumnCount, RowCount
mov rbp,rcx # reload row length remaining
.if ((\RowCount\() & 1) == 0)
sub rbp,4*4
jb .LProcessRemainingBlocks\@
.LComputeBlockBy4Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0*64, 0
ComputeBlock \ColumnCount\(), \RowCount\(), 1*64, 4
ComputeBlock \ColumnCount\(), \RowCount\(), 2*64, 8
ComputeBlock \ColumnCount\(), \RowCount\(), 3*64, 12
add rdi,4*4 # advance matrix A by 1 quad
.if \RowCount\() > 3
add rbx,4*4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,4*64 # advance matrix B
sub rbp,4*4 # decrement quads remaining
jae .LComputeBlockBy4Loop\@
.LProcessRemainingBlocks\@:
add rbp,4*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
.endif
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 quad
.endif
add rsi,64 # advance matrix B
sub rbp,4 # decrement quads remaining
jnz .LComputeBlockBy1Loop\@
.LComputeBlockLoopExit\@:
.endm

View file

@ -0,0 +1,106 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8S8KernelAvx512Vnni.s
Abstract:
This module implements the kernels for the quantized integer matrix/matrix
multiply operation (QGEMM).
This implementation uses AVX512VNNI instructions.
--*/
#include "asmmacro.h"
#include "QgemmU8S8KernelAvx512Common.h"
#include "AssembleAvx512Vnni.h"
.intel_syntax noprefix
.text
/*++
Macro Description:
This macro generates code to multiply and accumulate each row of the output
block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
.if \ColumnCount\() >= 48
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.elseif \ColumnCount\() >= 32
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()]
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()]
.else
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()]
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2"
.endm
//
// Generate the GEMM kernel.
//
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
.end

View file

@ -18,11 +18,10 @@ Abstract:
--*/
#include "asmmacro.h"
#include "QgemmU8X8KernelAvx2Common.h"
.intel_syntax noprefix
.text
//
// Stack frame layout for the U8U8 CopyPackA routine.
//
@ -47,22 +46,7 @@ Abstract:
.equ .LGemmU8U8CopyPackBFrame_ReturnAddress, 16
.equ .LGemmU8U8CopyPackBFrame_offa, 24
//
// Stack frame layout for the U8U8 kernel.
//
.equ .LGemmU8U8KernelFrame_mask, -8
.equ .LGemmU8U8KernelFrame_SavedR14, 0
.equ .LGemmU8U8KernelFrame_SavedR13, 8
.equ .LGemmU8U8KernelFrame_SavedR12, 16
.equ .LGemmU8U8KernelFrame_SavedRbx, 24
.equ .LGemmU8U8KernelFrame_SavedRbp, 32
.equ .LGemmU8U8KernelFrame_ReturnAddress, 40
.equ .LGemmU8U8KernelFrame_ldc, 48
.equ .LGemmU8U8KernelFrame_RowSumVector, 56
.equ .LGemmU8U8KernelFrame_ColumnSumVector, 64
.equ .LGemmU8U8KernelFrame_DepthValue, 72
.equ .LGemmU8U8KernelFrame_ZeroMode, 80
.text
/*++
@ -138,13 +122,15 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
//
// Process 4 rows of matrix A in a loop.
//
// For each row, zero extend the source bytes to 16-bits and write to the packed
// buffer. The packed buffer has the same data ordering as the source bytes, but
// the stride is CountK aligned up to an even number of 16-bit values.
// Zero extend the source bytes to 16-bits and write to the packed buffer.
//
// The packed buffer has the same data ordering as the source bytes, but CountK
// is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding
// bytes are zero filled.
//
// These 16-bit values are also accumulated into an intermediate per-row
// accumulator. CountK cannot be greater than 256 to avoid overflowing these
// 16-bit accumulators.
// accumulator. CountK cannot be greater than 128 to avoid overflowing these
// signed 16-bit accumulators.
//
sub r11,4
@ -164,12 +150,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
jb .LCopyPackA.ProcessRemainingColumnsM4
.LCopyPackA.ProcessNextColumnLoopM4:
lea rax,[rdx+r10*2] # compute matrix A plus two rows
lea rax,[rdx+r10*2] # compute matrix A plus 2 rows
vpmovzxbw ymm4,XMMWORD PTR [rdx]
vpmovzxbw ymm5,XMMWORD PTR [rdx+r10]
vpmovzxbw ymm6,XMMWORD PTR [rax]
vpmovzxbw ymm7,XMMWORD PTR [rax+r10]
lea rax,[rcx+r12*4] # compute matrix D plus two rows
lea rax,[rcx+r12*4] # compute matrix D plus 2 rows
vmovdqu YMMWORD PTR [rcx],ymm4
vmovdqu YMMWORD PTR [rcx+r12*2],ymm5
vmovdqu YMMWORD PTR [rax],ymm6
@ -194,7 +180,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp]
test bl,8 # (CountK & 8) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan8M4
lea r13,[rdx+r10*2] # compute matrix A plus two rows
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
mov rax,QWORD PTR [rdx+r10]
@ -209,7 +195,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
.LCopyPackA.CopyRemainingCountKLessThan8M4:
test bl,4 # (CountK & 4) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan4M4
lea r13,[rdx+r10*2] # compute matrix A plus two rows
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
mov eax,DWORD PTR [rdx+r10]
@ -224,7 +210,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
.LCopyPackA.CopyRemainingCountKLessThan4M4:
test bl,2 # (CountK & 2) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan2M4
lea r13,[rdx+r10*2] # compute matrix A plus two rows
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
movzx eax,WORD PTR [rdx+r10]
@ -239,7 +225,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
.LCopyPackA.CopyRemainingCountKLessThan2M4:
test bl,1 # (CountK & 1) != 0?
jz .LCopyPackA.ProcessPaddedMatrixADataM4
lea r13,[rdx+r10*2] # compute matrix A plus two rows
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
movzx eax,BYTE PTR [rdx+r10]
@ -258,7 +244,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
vpmovzxbw ymm5,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+16]
vpmovzxbw ymm6,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32]
vpmovzxbw ymm7,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+48]
lea rax,[rcx+r12*4] # compute matrix D plus two rows
lea rax,[rcx+r12*4] # compute matrix D plus 2 rows
vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4
vpmaskmovd YMMWORD PTR [rcx+r12*2],ymm9,ymm5
vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6
@ -276,20 +262,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
//
.LCopyPackA.ReduceRowSumVectorM4:
vpunpckldq ymm4,ymm0,ymm1 # [A5 B5 A4 B4 A1 B1 A0 B0]
vpunpckhdq ymm5,ymm0,ymm1 # [A7 B7 A6 B6 A3 B3 A2 B2]
vpunpckldq ymm6,ymm2,ymm3 # [C5 D5 C4 D4 C1 D1 C0 D0]
vpunpckhdq ymm7,ymm2,ymm3 # [C7 D7 C6 D6 C3 D3 C2 D2]
vpunpcklqdq ymm0,ymm4,ymm6 # [A4 B4 C4 D4 A0 B0 C0 D0]
vpunpckhqdq ymm1,ymm4,ymm6 # [A5 B5 C5 D5 A1 B1 C1 D1]
vpunpcklqdq ymm2,ymm5,ymm7 # [A6 B6 C6 D6 A2 B2 C2 D2]
vpunpckhqdq ymm3,ymm5,ymm7 # [A7 B7 C7 D7 A3 B3 C3 D3]
vpaddw ymm0,ymm0,ymm1 # reduction
vpaddw ymm0,ymm0,ymm2
vpaddw ymm0,ymm0,ymm3
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0
vphaddw ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0
vextracti128 xmm1,ymm0,1 # extract high pairs
vpaddw xmm0,xmm0,xmm1 # reduction
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce
vpaddw xmm0,xmm0,xmm1 # reduce low/high pairs
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce 32-bit sum
vmovdqu XMMWORD PTR [r9],xmm0
add r9,4*4 # advance row sum vector by 4 dwords
sub r11,4 # subtract rows remaining
@ -436,7 +414,6 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
push rbx
mov r10,rdx
mov r11,rcx
vpbroadcastw ymm5,WORD PTR .LGemmU8U8CopyPackBFrame_offa[rsp]
//
@ -450,7 +427,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
// Process 16 columns of matrix B in a loop.
//
sub r11,16
sub rcx,16
jb .LCopyPackB.ProcessRemainingColumns
.LCopyPackB.ProcessNextColumnN16:
@ -463,9 +440,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
jb .LCopyPackB.ProcessRemainingRowsN16
.LCopyPackB.ProcessNextRowLoopN16:
vmovdqu xmm2,XMMWORD PTR [rdx] # load two rows
vmovdqu xmm2,XMMWORD PTR [rdx] # load 2 rows
vmovdqu xmm3,XMMWORD PTR [rdx+r10]
lea rdx,[rdx+r10*2] # advance matrix B by two rows
lea rdx,[rdx+r10*2] # advance matrix B by 2 rows
vpunpcklbw xmm4,xmm2,xmm3 # interleave row data
vpunpckhbw xmm3,xmm2,xmm3
vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows
@ -475,7 +452,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rdi,32 # advance matrix D by 32 bytes
vpaddw ymm0,ymm0,ymm4 # accumulate per column
vpaddw ymm1,ymm1,ymm3
sub rbx,2 # subtract columns remaining
sub rbx,2 # subtract rows remaining
jae .LCopyPackB.ProcessNextRowLoopN16
.LCopyPackB.ProcessRemainingRowsN16:
@ -495,12 +472,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
vpmaddwd ymm1,ymm1,ymm5 # multiply by offset and reduce
vmovdqu YMMWORD PTR [r9],ymm0
vmovdqu YMMWORD PTR [r9+32],ymm1
add r9,64 # advance column sum vector by 16 dwords
sub r11,16 # subtract columns remaining
add r9,64 # advance column sum vector by 16 DWORDs
sub rcx,16 # subtract columns remaining
jae .LCopyPackB.ProcessNextColumnN16
.LCopyPackB.ProcessRemainingColumns:
add r11,16 # correct for over-subtract above
add rcx,16 # correct for over-subtract above
jnz .LCopyPackB.ProcessColumnNUnaligned
//
@ -527,7 +504,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
.LCopyPackB.ProcessNextRowLoopNUnaligned:
mov rdx,rsi
lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp]
test r11b,8 # (CountN & 8) != 0?
test cl,8 # (CountN & 8) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan8K2
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
@ -537,7 +514,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rbp,8 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan8K2:
test r11b,4 # (CountN & 4) != 0?
test cl,4 # (CountN & 4) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan4K2
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
@ -547,7 +524,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rbp,4 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan4K2:
test r11b,2 # (CountN & 2) != 0?
test cl,2 # (CountN & 2) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan2K2
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
@ -557,7 +534,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rbp,2 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan2K2:
test r11b,1 # (CountN & 1) != 0?
test cl,1 # (CountN & 1) != 0?
jz .LCopyPackB.ProcessPaddedMatrixBDataK2
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
@ -575,9 +552,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
vpmovzxbw ymm3,xmm3
vpaddw ymm0,ymm0,ymm4 # accumulate per column
vpaddw ymm1,ymm1,ymm3
lea rsi,[rsi+r10*2] # advance next matrix B by two rows
lea rsi,[rsi+r10*2] # advance next matrix B by 2 rows
add rdi,32 # advance matrix D by 32 bytes
sub r8,2 # subtract columns remaining
sub r8,2 # subtract rows remaining
jae .LCopyPackB.ProcessNextRowLoopNUnaligned
.LCopyPackB.ProcessRemainingRowsNUnaligned:
@ -585,7 +562,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
jz .LCopyPackB.ReduceColumnSumVectorNUnaligned
mov rdx,rsi
lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp]
test r11b,8 # (CountN & 8) != 0?
test cl,8 # (CountN & 8) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan8K1
mov rax,QWORD PTR [rdx]
mov QWORD PTR [rbp],rax
@ -593,7 +570,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rbp,8 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan8K1:
test r11b,4 # (CountN & 4) != 0?
test cl,4 # (CountN & 4) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan4K1
mov eax,DWORD PTR [rdx]
mov DWORD PTR [rbp],eax
@ -601,7 +578,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rbp,4 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan4K1:
test r11b,2 # (CountN & 2) != 0?
test cl,2 # (CountN & 2) != 0?
jz .LCopyPackB.CopyRemainingCountNLessThan2K1
movzx eax,WORD PTR [rdx]
mov WORD PTR [rbp],ax
@ -609,7 +586,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
add rbp,2 # advance padded buffer destination
.LCopyPackB.CopyRemainingCountNLessThan2K1:
test r11b,1 # (CountN & 1) != 0?
test cl,1 # (CountN & 1) != 0?
jz .LCopyPackB.ProcessPaddedMatrixBDataK1
movzx eax,BYTE PTR [rdx]
mov BYTE PTR [rbp],al
@ -659,13 +636,12 @@ Implicit Arguments:
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
.if \ColumnCount\() == 16
vpmaddwd ymm3,ymm2,ymm0
.if \ColumnCount\() == 16
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
vpmaddwd ymm2,ymm2,ymm1
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
.else
vpmaddwd ymm3,ymm2,ymm0
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
.endif
@ -696,7 +672,7 @@ Implicit Arguments:
rsi - Supplies the address into the matrix B data.
r10 - Supplies the length in bytes of a row from matrix A.
rcx - Supplies the length in bytes of a row from matrix A.
ymm4-ymm15 - Supplies the block accumulators.
@ -708,15 +684,15 @@ Implicit Arguments:
EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]"
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+r10+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+r10*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+r10+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+r10*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15"
.endm
@ -725,8 +701,8 @@ Implicit Arguments:
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows.
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
@ -736,271 +712,55 @@ Arguments:
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the number of paired columns from matrix A and the number of
paired rows from matrix B to iterate over.
rcx - Supplies the length in bytes of a row from matrix A.
r10 - Supplies the length in bytes of a row from matrix A.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
.macro ComputeBlockLoop ColumnCount, RowCount
//
// Initialize the accumulators with the sum of the global depth value constant,
// the column sums, and the row sums.
//
vpbroadcastd ymm1,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp]
.if \ColumnCount\() == 16
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
add r13,16*4 # advance ColumnSumVector by 16 columns
.else
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
//
// Iterate over PairedCountK elements from matrix A and matrix B.
//
// Unrolling the loop to do two iterations improves performance slightly at the
// cost of larger code size. Balance this by only unrolling for the common case
// of computing 16 columns for an even number of rows.
//
mov rbp,rcx # reload PairedCountK
.if \RowCount\() > 3
lea rbx,[r10*2+r10]
add rbx,rdi # compute matrix A plus 3 rows
.endif
mov rbp,rcx # reload row length remaining
.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0)
sub rbp,2
jb .LProcessRemainingBlocks.\ColumnCount\().\RowCount\()
sub rbp,2*4
jb .LProcessRemainingBlocks\@
.LComputeBlockLoop.\ColumnCount\().\RowCount\():
.LComputeBlockBy2Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4
add rdi,2*4 # advance matrix A by 2 pairs
.if \RowCount\() > 3
add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs
.endif
add rsi,2*32 # advance matrix B by 64 columns
sub rbp,2 # subtract pairs remaining
jae .LComputeBlockLoop.\ColumnCount\().\RowCount\()
add rsi,2*32 # advance matrix B
sub rbp,2*4
jae .LComputeBlockBy2Loop\@
.LProcessRemainingBlocks.\ColumnCount\().\RowCount\():
add rbp,2 # correct for over-subtract above
jz .LComputeBlockLoopExit.\ColumnCount\().\RowCount\()
.LProcessRemainingBlocks\@:
add rbp,2*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rsi,32 # advance matrix B by 32 columns
add rsi,32 # advance matrix B
.else
.LComputeBlockLoop.\ColumnCount\().\RowCount\():
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 pair
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 pair
.endif
add rsi,32
dec rbp # decrement pairs remaining
jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\()
add rsi,32 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endif
.LComputeBlockLoopExit.\ColumnCount\().\RowCount\():
.if \RowCount\() > 3
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
add rbx,rax
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows.
Arguments:
RowCount - Supplies the number of rows to process.
Fallthrough - Supplies a non-blank value if the macro may fall through to
the ExitKernel label.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
r11 - Supplies the address of matrix A.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the number of paired columns from matrix A and the number of
paired rows from matrix B to iterate over.
r10 - Supplies the length in bytes of a row from matrix A.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
r14b - Supplies the zero mode flag.
--*/
.macro ProcessCountM RowCount, Fallthrough
cmp r9,8
jbe .LProcessRemainingCountN.\RowCount\()
.LProcessNextColumnLoop16xN.\RowCount\():
ProduceOutputBlock 16, \RowCount\()
sub r9,16
jb .LOutputMasked16xNBlock.\RowCount\()
test r14b,r14b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlock.\RowCount\()
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]"
.LSkipAccumulateOutput16xNBlock.\RowCount\():
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,r11 # reload matrix A
cmp r9,8
ja .LProcessNextColumnLoop16xN.\RowCount\()
test r9,r9
jz .LExitKernel
.LProcessRemainingCountN.\RowCount\():
ProduceOutputBlock 8, \RowCount\()
cmp r9,8
jb .LOutputMasked8xNBlock.\RowCount\()
test r14b,r14b # ZeroMode?
jnz .LSkipAccumulateOutput8xNBlock.\RowCount\()
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput8xNBlock.\RowCount\():
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15"
jmp .LExitKernel
.LOutputMasked16xNBlock.\RowCount\():
test r14b,r14b # ZeroMode?
jnz .LSkipAccumulateOutputMasked16xNBlock.\RowCount\()
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutputMasked16xNBlock.\RowCount\():
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
add rdx,8*4 # advance matrix C by 8 columns
.if \RowCount\() > 3
add rbx,8*4 # advance matrix C plus 3 rows by 8 columns
.endif
add r9,8 # correct for over-subtract above
.LOutputMasked8xNBlock.\RowCount\():
mov DWORD PTR .LGemmU8U8KernelFrame_mask[rsp],r9d
vpbroadcastd ymm0,DWORD PTR .LGemmU8U8KernelFrame_mask[rsp]
vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
test r14b,r14b # ZeroMode?
jnz .LSkipAccumulateOutputMasked8xNBlock.\RowCount\()
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
.LSkipAccumulateOutputMasked8xNBlock.\RowCount\():
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15"
.ifb \Fallthrough\()
jmp .LExitKernel
.endif
.LComputeBlockLoopExit\@:
.endm
@ -1021,8 +781,8 @@ Arguments:
C (rdx) - Supplies the address of matrix C.
PairedCountK (rcx) - Supplies the number of paired columns from matrix A and
the number of paired rows from matrix B to iterate over.
PairCountK (rcx) - Supplies the number of pair columns from matrix A and
the number of pair rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
@ -1042,8 +802,8 @@ Arguments:
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrixA multplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
of matrix A multplied by the zero point offset of matrix B. This value
is accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
@ -1061,15 +821,14 @@ C_UNDERSCORE(MlasGemmU8U8KernelAvx2):
push rbx
push r12
push r13
push r14
mov rax,.LGemmU8U8KernelFrame_ldc[rsp]
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
lea r10,[rcx*4]
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp]
mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp]
movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp]
mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp]
//
// Process CountM rows of the matrices.
@ -1102,7 +861,6 @@ C_UNDERSCORE(MlasGemmU8U8KernelAvx2):
mov eax,r8d
vzeroupper
pop r14
pop r13
pop r12
pop rbx

View file

@ -28,40 +28,27 @@ Abstract:
Macro Description:
This macro generates code to multiply and accumulator a single row of the
This macro generates code to multiply and accumulator a single cell of the
output block.
Arguments:
ColumnCount - Supplies the number of columns to produce.
AccumReg - Supplies the register to accumulate into.
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
is 32).
Mult1Reg - Supplies the first multiplication operand register.
Vec2Reg - Supplies the low block accumulator register.
Mult2Reg - Supplies the second multiplication operand register.
Implicit Arguments:
zmm28 - Supplies the first vector loaded from matrix B.
zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount
is 32).
zmm30 - Supplies the broadcast value loaded from matrix A.
zmm4 - Supplies a scratch register for intermediate results.
--*/
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
.if \ColumnCount\() == 32
vpmaddwd zmm31,zmm30,zmm28
vpaddd \Vec1Reg\(),\Vec1Reg\(),zmm31
vpmaddwd zmm30,zmm30,zmm29
vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm30
.else
vpmaddwd zmm31,zmm30,zmm28
vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm31
.endif
vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\()
vpaddd \AccumReg\(),\AccumReg\(),zmm4
.endm
@ -78,36 +65,62 @@ Arguments:
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
rdi - Supplies the address into the matrix A data.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
r10 - Supplies the length in bytes of a row from matrix A.
rcx - Supplies the length in bytes of a row from matrix A.
zmm16-zmm27 - Supplies the block accumulators.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
vpmovzxbw zmm28,YMMWORD PTR [rsi]
EmitIfCountGE \ColumnCount\(), 32, "vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8]"
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]"
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), zmm16, zmm17"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]"
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), zmm18, zmm19"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]"
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), zmm20, zmm21"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), zmm22, zmm23"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]"
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), zmm24, zmm25"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]"
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), zmm26, zmm27"
.if \ColumnCount\() >= 48
vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()]
vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()]
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.elseif \ColumnCount\() >= 32
vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()]
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()]
.else
vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()]
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
.endm
@ -115,6 +128,6 @@ Implicit Arguments:
// Generate the GEMM kernel.
//
GemmU8U8KernelAvx512Function Avx512BW
GemmU8X8KernelAvx512Function U8U8, Avx512BW
.end

View file

@ -16,28 +16,14 @@ Abstract:
--*/
//
// Stack frame layout for the U8U8 kernel.
//
.equ .LGemmU8U8KernelFrame_SavedR14, 0
.equ .LGemmU8U8KernelFrame_SavedR13, 8
.equ .LGemmU8U8KernelFrame_SavedR12, 16
.equ .LGemmU8U8KernelFrame_SavedRbx, 24
.equ .LGemmU8U8KernelFrame_SavedRbp, 32
.equ .LGemmU8U8KernelFrame_ReturnAddress, 40
.equ .LGemmU8U8KernelFrame_ldc, 48
.equ .LGemmU8U8KernelFrame_RowSumVector, 56
.equ .LGemmU8U8KernelFrame_ColumnSumVector, 64
.equ .LGemmU8U8KernelFrame_DepthValue, 72
.equ .LGemmU8U8KernelFrame_ZeroMode, 80
#include "QgemmU8X8KernelAvx512Common.h"
/*++
Macro Description:
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows.
This macro generates code to execute the block compute macro multiple
times and advancing the matrix A and matrix B data pointers.
Arguments:
@ -47,315 +33,32 @@ Arguments:
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the number of paired columns from matrix A and the number of
paired rows from matrix B to iterate over.
rcx - Supplies the length in bytes of a row from matrix A.
r10 - Supplies the length in bytes of a row from matrix A.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
.macro ComputeBlockLoop ColumnCount, RowCount
//
// Initialize the accumulators with the sum of the global depth value constant,
// the column sums, and the row sums.
//
mov rbp,rcx # reload row length remaining
vpbroadcastd zmm31,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp]
.if \ColumnCount\() == 32
vpaddd zmm30,zmm31,ZMMWORD PTR [r13]
vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64]
add r13,32*4 # advance ColumnSumVector by 32 columns
.else
vpaddd zmm31,zmm31,ZMMWORD PTR [r13]
.endif
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm16,zmm30,DWORD PTR [r12]{1to16}"
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17,zmm31,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm18,zmm30,DWORD PTR [r12+4]{1to16}"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19,zmm31,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm20,zmm30,DWORD PTR [r12+8]{1to16}"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21,zmm31,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm22,zmm30,DWORD PTR [r12+12]{1to16}"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm31,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm30,DWORD PTR [r12+16]{1to16}"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25,zmm31,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm26,zmm30,DWORD PTR [r12+20]{1to16}"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27,zmm31,DWORD PTR [r12+20]{1to16}"
//
// Iterate over PairedCountK elements from matrix A and matrix B.
//
mov rbp,rcx # reload PairedCountK
.if \RowCount\() > 3
lea rbx,[r10*2+r10]
add rbx,rdi # compute matrix A plus 3 rows
.endif
.LComputeBlockLoop.\ColumnCount\().\RowCount\():
ComputeBlock \ColumnCount\(), \RowCount\()
.LComputeBlockBy1Loop\@:
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 pair
.if \RowCount\() > 3
add rbx,4 # advance matrix A plus 3 rows by 1 pair
.endif
add rsi,32
dec rbp # decrement pairs remaining
jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\()
.if \RowCount\() > 3
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
add rbx,rax
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows.
Arguments:
RowCount - Supplies the number of rows to process.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
r11 - Supplies the address of matrix A.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the number of paired columns from matrix A and the number of
paired rows from matrix B to iterate over.
r10 - Supplies the length in bytes of a row from matrix A.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
r14b - Supplies the zero mode flag.
--*/
.macro ProcessCountM RowCount
cmp r9,16
jbe .LProcessRemainingCountN.\RowCount\()
.LProcessNextColumnLoop32xN.\RowCount\():
ProduceOutputBlock 32, \RowCount\()
lea rsi,[rsi+r10*8] # advance matrix B by 8*PairedCountK
test r14b,r14b # ZeroMode?
jnz .LSkipAccumulateOutput32xNBlock.\RowCount\()
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm16,zmm16,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm18,zmm18,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm22,zmm22,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm26,zmm26,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput32xNBlock.\RowCount\():
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm16"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm18"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm20"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm22"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm26"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
.LOutput16xNBlock.\RowCount\():
sub r9,16
jae .LOutput16xNBlockWithMask.\RowCount\()
lea rcx,[r9+16] # correct for over-subtract above
mov ebp,1
shl ebp,cl
dec ebp
kmovw k1,ebp # update mask for remaining columns
xor r9,r9 # no more columns remaining
.LOutput16xNBlockWithMask.\RowCount\():
test r14b,r14b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlockWithMask.\RowCount\()
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21{k1},zmm21,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23{k1},zmm23,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25{k1},zmm25,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27{k1},zmm27,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput16xNBlockWithMask.\RowCount\():
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm17"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm19"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm21"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm23"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm25"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm27"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,r11 # reload matrix A
cmp r9,16
ja .LProcessNextColumnLoop32xN.\RowCount\()
test r9,r9
jz .LExitKernel
.LProcessRemainingCountN.\RowCount\():
ProduceOutputBlock 16, \RowCount\()
jmp .LOutput16xNBlock.\RowCount\()
.endm
/*++
Macro Description:
This macro generates the common AVX512 code for the inner kernel to compute
matrix multiplication.
Arguments:
Isa - Supplies the instruction set architecture string for function tags.
--*/
.macro GemmU8U8KernelAvx512Function Isa
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8U8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8U8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PairedCountK (rcx) - Supplies the number of paired columns from matrix A and
the number of paired rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrixA multplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\())
C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\()):
push rbp
push rbx
push r12
push r13
push r14
mov rax,.LGemmU8U8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
lea r10,[rcx*4]
mov r11,rdi
mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp]
mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp]
movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp]
mov ebp,-1
kmovw k1,ebp # update mask to write all columns
//
// Process CountM rows of the matrices.
//
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
mov r8d,6 # return 6 rows handled
ProcessCountM 6
//
// Restore non-volatile registers and return.
//
.LExitKernel:
mov eax,r8d
pop r14
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
add rsi,32 # advance matrix B
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.endm

View file

@ -38,50 +38,69 @@ Arguments:
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
rdi - Supplies the address into the matrix A data.
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
Implicit Arguments:
rbx - Supplies the address into the matrix A data plus 3 rows.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
r10 - Supplies the length in bytes of a row from matrix A.
rcx - Supplies the length in bytes of a row from matrix A.
zmm16-zmm27 - Supplies the block accumulators.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
zmm14-zmm31 - Supplies the block accumulators.
--*/
.macro ComputeBlock ColumnCount, RowCount
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
vpmovzxbw zmm28,YMMWORD PTR [rsi]
.if \ColumnCount\() == 32
vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8]
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]"
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm16,zmm28,zmm30"
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm17,zmm29,zmm30"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]"
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm18,zmm28,zmm30"
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm19,zmm29,zmm30"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]"
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm20,zmm28,zmm30"
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm21,zmm29,zmm30"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm22,zmm28,zmm30"
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm23,zmm29,zmm30"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]"
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm24,zmm28,zmm30"
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm25,zmm29,zmm30"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]"
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm26,zmm28,zmm30"
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm27,zmm29,zmm30"
.if \ColumnCount\() >= 32
.if \ColumnCount\() >= 48
vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()]
vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()]
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()]
.else
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm17,zmm28,rdi"
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm19,zmm28,rdi,r10,1"
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm21,zmm28,rdi,r10,2"
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm23,zmm28,rbx"
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm25,zmm28,rbx,r10,1"
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm27,zmm28,rbx,r10,2"
vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()]
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()]
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm26,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm20,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm14,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm27,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm21,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm15,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm28,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm22,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm16,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm29,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm23,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm17,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm30,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm24,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm18,zmm3,zmm2"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm31,zmm3,zmm0"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm25,zmm3,zmm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm19,zmm3,zmm2"
.else
vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()]
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm14,zmm2,rdi,\BroadcastOffset\()"
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm15,zmm2,rdi,\BroadcastOffset\(),rcx,1"
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm16,zmm2,rdi,\BroadcastOffset\(),rcx,2"
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm17,zmm2,rbx,\BroadcastOffset\()"
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm18,zmm2,rbx,\BroadcastOffset\(),rcx,1"
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm19,zmm2,rbx,\BroadcastOffset\(),rcx,2"
.endif
.endm
@ -90,6 +109,6 @@ Implicit Arguments:
// Generate the GEMM kernel.
//
GemmU8U8KernelAvx512Function Avx512Vnni
GemmU8X8KernelAvx512Function U8U8, Avx512Vnni
.end

View file

@ -0,0 +1,273 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelAvx2Common.h
Abstract:
This module contains common kernel macros and structures for the quantized
integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
--*/
//
// Stack frame layout for the U8S8 and U8U8 kernels.
//
.equ .LGemmU8X8KernelFrame_mask, -8
.equ .LGemmU8X8KernelFrame_SavedR13, 0
.equ .LGemmU8X8KernelFrame_SavedR12, 8
.equ .LGemmU8X8KernelFrame_SavedRbx, 16
.equ .LGemmU8X8KernelFrame_SavedRbp, 24
.equ .LGemmU8X8KernelFrame_ReturnAddress, 32
.equ .LGemmU8X8KernelFrame_ldc, 40
.equ .LGemmU8X8KernelFrame_RowSumVector, 48
.equ .LGemmU8X8KernelFrame_ColumnSumVector, 56
.equ .LGemmU8X8KernelFrame_DepthValue, 64
.equ .LGemmU8X8KernelFrame_ZeroMode, 72
/*++
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
ymm4-ymm15 - Supplies the block accumulators.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
//
// Initialize the accumulators with the sum of the global depth value constant,
// the column sums, and the row sums.
//
vpbroadcastd ymm1,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
.if \ColumnCount\() == 16
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
add r13,16*4 # advance ColumnSumVector by 16 columns
.else
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
.endif
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]"
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]"
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]"
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]"
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]"
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
//
// Iterate over the length of a matrix A row to produce the output accumulators.
//
.if \RowCount\() > 3
lea rbx,[rcx*2+rcx]
add rbx,rdi # compute matrix A plus 3 rows
.endif
ComputeBlockLoop \ColumnCount\(), \RowCount\()
.if \RowCount\() > 3
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
add rbx,rax
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows.
Arguments:
RowCount - Supplies the number of rows to process.
Fallthrough - Supplies a non-blank value if the macro may fall through to
the ExitKernel label.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
r11 - Supplies the address of matrix A.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the length in bytes of a row from matrix A.
r10b - Supplies the zero mode flag.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
--*/
.macro ProcessCountM RowCount, Fallthrough
cmp r9,8
jbe .LProcessRemainingCountN\@
.LProcessNextColumnLoop16xN\@:
ProduceOutputBlock 16, \RowCount\()
sub r9,16
jb .LOutputMasked16xNBlock\@
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]"
.LSkipAccumulateOutput16xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,r11 # reload matrix A
cmp r9,8
ja .LProcessNextColumnLoop16xN\@
test r9,r9
jz .LExitKernel
.LProcessRemainingCountN\@:
ProduceOutputBlock 8, \RowCount\()
cmp r9,8
jb .LOutputMasked8xNBlock\@
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput8xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput8xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15"
jmp .LExitKernel
.LOutputMasked16xNBlock\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutputMasked16xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutputMasked16xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
add rdx,8*4 # advance matrix C by 8 columns
.if \RowCount\() > 3
add rbx,8*4 # advance matrix C plus 3 rows by 8 columns
.endif
add r9,8 # correct for over-subtract above
.LOutputMasked8xNBlock\@:
mov DWORD PTR .LGemmU8X8KernelFrame_mask[rsp],r9d
vpbroadcastd ymm0,DWORD PTR .LGemmU8X8KernelFrame_mask[rsp]
vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutputMasked8xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]"
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
.LSkipAccumulateOutputMasked8xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11"
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13"
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15"
.ifb \Fallthrough\()
jmp .LExitKernel
.endif
.endm

View file

@ -0,0 +1,403 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
QgemmU8X8KernelAvx512Common.h
Abstract:
This module contains common kernel macros and structures for the quantized
integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
AVX512VNNI kernels.
--*/
//
// Stack frame layout for the U8S8 and U8U8 kernels.
//
.equ .LGemmU8X8KernelFrame_SavedR14, 0
.equ .LGemmU8X8KernelFrame_SavedR13, 8
.equ .LGemmU8X8KernelFrame_SavedR12, 16
.equ .LGemmU8X8KernelFrame_SavedRbx, 24
.equ .LGemmU8X8KernelFrame_SavedRbp, 32
.equ .LGemmU8X8KernelFrame_ReturnAddress, 40
.equ .LGemmU8X8KernelFrame_ldc, 48
.equ .LGemmU8X8KernelFrame_RowSumVector, 56
.equ .LGemmU8X8KernelFrame_ColumnSumVector, 64
.equ .LGemmU8X8KernelFrame_DepthValue, 72
.equ .LGemmU8X8KernelFrame_ZeroMode, 80
/*++
Macro Description:
This macro generates code to produce an output block for a set of columns
and rows.
Arguments:
ColumnCount - Supplies the number of columns to produce.
RowCount - Supplies the number of rows to produce.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address into the matrix A data.
rsi - Supplies the address into the matrix B data.
rcx - Supplies the length in bytes of a row from matrix A.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
--*/
.macro ProduceOutputBlock ColumnCount, RowCount
//
// Initialize the accumulators with the sum of the global depth value constant,
// the column sums, and the row sums.
//
vpbroadcastd zmm3,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
.if \ColumnCount\() >= 32
.if \ColumnCount\() >= 48
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
.else
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
.endif
add_immed r13,\ColumnCount\()*4 # advance ColumnSumVector by N columns
.else
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
.endif
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r12]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r12+4]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r12+8]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r12+12]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r12+16]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r12+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r12+20]{1to16}"
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r12+20]{1to16}"
//
// Iterate over the length of a matrix A row to produce the output accumulators.
//
.if \RowCount\() > 3
lea rbx,[rcx*2+rcx]
add rbx,rdi # compute matrix A plus 3 rows
.endif
ComputeBlockLoop \ColumnCount\(), \RowCount\()
.if \RowCount\() > 3
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
add rbx,rax
.endif
.endm
/*++
Macro Description:
This macro generates code to compute matrix multiplication for a fixed set
of rows.
Arguments:
RowCount - Supplies the number of rows to process.
Implicit Arguments:
rax - Supplies the length in bytes of a row from matrix C.
rdi - Supplies the address of matrix A.
rsi - Supplies the address of matrix B.
rdx - Supplies the address of matrix C.
r11 - Supplies the address of matrix A.
r9 - Supplies the number of columns from matrix B and matrix C to iterate
over.
rcx - Supplies the length in bytes of a row from matrix A.
r10b - Supplies the zero mode flag.
r12 - Supplies the address of the row sum vector.
r13 - Supplies the address of the column sum vector.
r14 - Supplies the stride in bytes of between packed blocks of matrix B.
--*/
.macro ProcessCountM RowCount
cmp r9,32
ja .LProcessNextColumnLoop48xN\@
cmp r9,16
jbe .LProcessRemainingCountN\@
.LProcessNextColumnLoop32xN\@:
ProduceOutputBlock 32, \RowCount\()
add rsi,r14 # advance matrix B by packed block stride
.LOutput32xNBlock\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput32xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput32xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm23"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
.LOutput16xNBlock\@:
sub r9,16
jae .LOutput16xNBlockWithMask\@
lea rcx,[r9+16] # correct for over-subtract above
mov ebp,1
shl ebp,cl
dec ebp
kmovw k1,ebp # update mask for remaining columns
xor r9,r9 # no more columns remaining
.LOutput16xNBlockWithMask\@:
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput16xNBlockWithMask\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput16xNBlockWithMask\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19"
add rdx,16*4 # advance matrix C by 16 columns
mov rdi,r11 # reload matrix A
cmp r9,32
ja .LProcessNextColumnLoop48xN\@
cmp r9,16
ja .LProcessNextColumnLoop32xN\@
test r9,r9
jz .LExitKernel
.LProcessRemainingCountN\@:
ProduceOutputBlock 16, \RowCount\()
jmp .LOutput16xNBlock\@
.LProcessNextColumnLoop48xN\@:
ProduceOutputBlock 48, \RowCount\()
lea rsi,[rsi+r14*2] # advance matrix B by packed block stride
test r10b,r10b # ZeroMode?
jnz .LSkipAccumulateOutput48xNBlock\@
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]"
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]"
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]"
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]"
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]"
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]"
.LSkipAccumulateOutput48xNBlock\@:
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26"
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27"
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28"
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm29"
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30"
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31"
add rdx,16*4 # advance matrix C by 16 columns
.if \RowCount\() > 3
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
.endif
sub r9,16
jmp .LOutput32xNBlock\@
.endm
/*++
Macro Description:
This macro generates the common AVX512 code for the inner kernel to compute
matrix multiplication.
Arguments:
Type - Supplies the kernel type string for function tags.
Isa - Supplies the instruction set architecture string for function tags.
--*/
.macro GemmU8X8KernelAvx512Function Type, Isa
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows.
Arguments:
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
using MlasGemmU8X8CopyPackAAvx2.
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
using MlasGemmU8X8CopyPackBAvx2.
C (rdx) - Supplies the address of matrix C.
PairedCountK (rcx) - Supplies the number of paired columns from matrix A and
the number of paired rows from matrix B to iterate over.
CountM (r8) - Supplies the maximum number of rows that can be processed for
matrix A and matrix C. The actual number of rows handled for this
invocation depends on the kernel implementation.
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldc - Supplies the first dimension of matrix C.
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
zero point offset of matrix B. These values are accumulated into every
row of matrix C.
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
by the zero point offset of matrix A. These values are accumulated into
every column of matrix C.
DepthValue - Supplies the value CountK multiplied by the zero point offset
of matrixA multplied by the zero point offset of matrix B. This value is
accumulated into every element of matrix C.
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the number of rows handled.
--*/
.globl C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\())
C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()):
push rbp
push rbx
push r12
push r13
push r14
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
shl rax,2 # convert ldc to bytes
shl rcx,2 # convert to row length
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
mov r11,rdi
mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp]
mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp]
mov ebp,-1
kmovw k1,ebp # update mask to write all columns
.ifeqs "\Type\()", "U8S8"
.ifeqs "\Isa\()", "Avx512BW"
neg ebp
vpbroadcastw zmm5,ebp # generate 512-bit word vector [0x0001]
.endif
mov r14,rcx
shl r14,4 # compute matrix B packed stride
.else
lea r14,[rcx*8] # compute matrix B packed stride
.endif
//
// Process CountM rows of the matrices.
//
cmp r8,5
ja .LProcessCountM6
je .LProcessCountM5
cmp r8,3
ja .LProcessCountM4
je .LProcessCountM3
cmp r8,1
je .LProcessCountM1
.LProcessCountM2:
ProcessCountM 2
.LProcessCountM4:
ProcessCountM 4
.LProcessCountM6:
mov r8d,6 # return 6 rows handled
ProcessCountM 6
//
// Restore non-volatile registers and return.
//
.LExitKernel:
mov eax,r8d
vzeroupper
pop r14
pop r13
pop r12
pop rbx
pop rbp
ret
.LProcessCountM1:
ProcessCountM 1
.LProcessCountM3:
ProcessCountM 3
.LProcessCountM5:
ProcessCountM 5
.endm

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/cpu/math/matmul_integer.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/qmath.h"
@ -35,6 +36,9 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
template <>
Status MatMulInteger<uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool();
auto a = ctx->Input<Tensor>(0);
auto b = ctx->Input<Tensor>(1);
ORT_ENFORCE(a != nullptr && b != nullptr);
@ -71,13 +75,16 @@ Status MatMulInteger<uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
b_offset,
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
static_cast<int>(helper.N()),
nullptr);
thread_pool);
}
return Status::OK();
}
template <>
Status MatMulInteger<uint8_t, int8_t>::Compute(OpKernelContext* ctx) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool();
auto a = ctx->Input<Tensor>(0);
auto b = ctx->Input<Tensor>(1);
ORT_ENFORCE(a != nullptr && b != nullptr);
@ -107,15 +114,19 @@ Status MatMulInteger<uint8_t, int8_t>::Compute(OpKernelContext* ctx) const {
}
}
// NOTE: Eigen based implementation is a reference implementation for accuracy only
for (int i = 0; i < static_cast<int>(helper.OutputOffsets().size()); i++) {
EigenCastGEMM<uint8_t, int8_t, int32_t>(
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
b->template Data<int8_t>() + helper.RightOffsets()[i],
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()));
QGemmu8s8_s32(static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
static_cast<int>(helper.K()),
0,
b->template Data<int8_t>() + helper.RightOffsets()[i],
static_cast<int>(helper.N()),
0,
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
static_cast<int>(helper.N()),
thread_pool);
}
return Status::OK();
}

View file

@ -3,9 +3,50 @@
#include "core/util/qmath.h"
#include "core/common/common.h"
#include "core/util/math_cpuonly.h"
#include "core/mlas/inc/mlas.h"
#if defined(_M_AMD64) || defined(__x86_64__) || defined(_M_IX86) || defined(__i386__)
#define MLAS_SUPPORTS_GEMM_U8X8
#else
// default to gemmlowp when building for arm devices
#ifndef USE_GEMMLOWP
#define USE_GEMMLOWP
#endif
#endif
#ifdef USE_GEMMLOWP
#include "core/util/gemmlowp_common.h"
#endif
namespace onnxruntime {
void QGemmu8s8_s32(
int M,
int N,
int K,
const uint8_t* lhs_data,
int lda,
const uint8_t lhs_offset,
const int8_t* rhs_data,
int ldb,
const int8_t rhs_offset,
int32_t* result_data,
int ldc,
concurrency::ThreadPool* thread_pool) {
#ifdef MLAS_SUPPORTS_GEMM_U8X8
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
#else
ORT_ENFORCE(lhs_offset == 0 && rhs_offset == 0, "For Eigen, zero point must be zero");
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For Eigen only RowMajor*RowMajor=RowMajor format is supported");
EigenCastGEMM<uint8_t, int8_t, int32_t>(lhs_data, rhs_data, result_data, M, N, K);
#endif
}
void QGemmu8u8_s32(
int M,
int N,
@ -21,13 +62,13 @@ void QGemmu8u8_s32(
concurrency::ThreadPool* thread_pool) {
#ifdef USE_GEMMLOWP
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported");
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported");
GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool);
GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool);
#else
MlasQgemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
#endif
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,26 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
// default to gemmlowp when building for arm devices
#ifndef USE_GEMMLOWP
#if defined(_M_ARM64) || defined(__aarch64__)
#define USE_GEMMLOWP
#endif
#if defined(_M_ARM) || defined(__arm__)
#define USE_GEMMLOWP
#endif
#endif
#ifdef USE_GEMMLOWP
#include "core/util/gemmlowp_common.h"
#else
#include "core/mlas/inc/mlas.h"
#endif
#include "core/platform/threadpool.h"
#include <mutex>
#include <thread>
namespace onnxruntime {
void QGemmu8s8_s32(
int M,
int N,
int K,
const uint8_t* lhs_data,
int lda,
const uint8_t lhs_offset,
const int8_t* rhs_data,
int ldb,
const int8_t rhs_offset,
int32_t* result_data,
int ldc,
concurrency::ThreadPool* thread_pool);
void QGemmu8u8_s32(
int M,
int N,

View file

@ -35,7 +35,7 @@ Abstract:
#endif
#if defined(_M_IX86) || defined(__i386__) || defined(_M_AMD64) || defined(__x86_64__)
#define MLAS_HAS_QGEMM_U8U8
#define MLAS_HAS_QGEMM_U8X8
#endif
MLAS_THREADPOOL* threadpool = nullptr;
@ -452,9 +452,10 @@ public:
}
};
#ifdef MLAS_HAS_QGEMM_U8U8
#ifdef MLAS_HAS_QGEMM_U8X8
class MlasQgemmU8U8Test : public MlasTestBase
template <typename xint8_t>
class MlasQgemmU8X8Test : public MlasTestBase
{
private:
void
@ -467,11 +468,11 @@ private:
)
{
const uint8_t* A = BufferA.GetBuffer(K * M);
const uint8_t* B = BufferB.GetBuffer(N * K);
const xint8_t* B = BufferB.GetBuffer(N * K);
int32_t* C = BufferC.GetBuffer(N * M);
int32_t* CReference = BufferCReference.GetBuffer(N * M);
Test(M, N, K, A, K, offa, B, N, offb, C, CReference, N);
Test(M, N, K, A, K, offa, B, N, xint8_t(offb), C, CReference, N);
}
void
@ -482,9 +483,9 @@ private:
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
const xint8_t* B,
size_t ldb,
uint8_t offb,
xint8_t offb,
int32_t* C,
int32_t* CReference,
size_t ldc
@ -493,7 +494,7 @@ private:
std::fill_n(C, M * N, -1);
std::fill_n(CReference, M * N, -1);
MlasQgemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool);
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool);
ReferenceQgemm(M, N, K, A, lda, offa, B, ldb, offb, CReference, ldc);
for (size_t f = 0; f < M * N; f++) {
@ -511,9 +512,9 @@ private:
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
const xint8_t* B,
size_t ldb,
uint8_t offb,
xint8_t offb,
int32_t* C,
size_t ldc
)
@ -523,7 +524,7 @@ private:
for (size_t n = 0; n < N; n++) {
const uint8_t* a = A + (m * lda);
const uint8_t* b = B + n;
const xint8_t* b = B + n;
int32_t* c = C + (m * ldc) + n;
int32_t sum = 0;
@ -539,7 +540,7 @@ private:
}
MatrixGuardBuffer<uint8_t> BufferA;
MatrixGuardBuffer<uint8_t> BufferB;
MatrixGuardBuffer<xint8_t> BufferB;
MatrixGuardBuffer<int32_t> BufferC;
MatrixGuardBuffer<int32_t> BufferCReference;
@ -565,7 +566,7 @@ public:
void
) override
{
static const uint8_t zero_points[] = { 0, 18, 128, 157, 231, 255 };
static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 };
for (size_t a = 0; a < _countof(zero_points); a++) {
uint8_t offa = zero_points[a];
@ -1970,9 +1971,10 @@ main(
printf("SGEMM tests.\n");
std::make_unique<MlasSgemmTest>()->ExecuteShort();
#ifdef MLAS_HAS_QGEMM_U8U8
#ifdef MLAS_HAS_QGEMM_U8X8
printf("QGEMM tests.\n");
std::make_unique<MlasQgemmU8U8Test>()->ExecuteShort();
std::make_unique<MlasQgemmU8X8Test<int8_t>>()->ExecuteShort();
std::make_unique<MlasQgemmU8X8Test<uint8_t>>()->ExecuteShort();
#endif
printf("Conv2D tests.\n");