mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
MLAS: add U8S8 MatMul operation (#1895)
Implement the second round of changes for quantization inside MLAS. This adds a MatMul operation for U8xS8=S32 for x86/x64 processors.
This commit is contained in:
parent
d3cb2a5572
commit
28a62f7728
31 changed files with 5415 additions and 1572 deletions
|
|
@ -44,6 +44,9 @@ if(MSVC)
|
|||
enable_language(ASM_MASM)
|
||||
|
||||
set(mlas_platform_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm
|
||||
|
|
@ -158,6 +161,7 @@ else()
|
|||
set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx")
|
||||
|
||||
set(mlas_platform_srcs_avx2
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S
|
||||
|
|
@ -175,6 +179,8 @@ else()
|
|||
set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f")
|
||||
|
||||
set(mlas_platform_srcs_avx512bw
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S
|
||||
)
|
||||
|
|
|
|||
|
|
@ -135,7 +135,24 @@ MlasSgemm(
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasQgemm(
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const int8_t* B,
|
||||
size_t ldb,
|
||||
int8_t offb,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg
|
|||
;
|
||||
; This macro builds a VNNI instruction of the form:
|
||||
;
|
||||
; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale]
|
||||
; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale+ByteOffset]
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
|
|
@ -151,15 +151,20 @@ VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg
|
|||
;
|
||||
; BaseReg - Specifies the base register of the broadcast operand.
|
||||
;
|
||||
; ByteOffset - Specifies the DWORD aligned byte offset for the broadcast
|
||||
; operand.
|
||||
;
|
||||
; IndexReg - Specifies the optional index register of the broadcast operand.
|
||||
;
|
||||
; Scale - Specifies the scaling factor of the optional index register.
|
||||
;
|
||||
|
||||
VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
LOCAL Payload0, Payload1, Payload2, ModRMByte, SibByte
|
||||
|
||||
.errnz (ByteOffset AND 3)
|
||||
|
||||
Payload0 = 002h ; "0F 38" prefix
|
||||
Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7)
|
||||
IFNB <IndexReg>
|
||||
|
|
@ -183,6 +188,9 @@ IFNB <IndexReg>
|
|||
ELSE
|
||||
ModRMByte = ModRMByte + (GprIndex_&BaseReg& AND 7)
|
||||
ENDIF
|
||||
IF ByteOffset NE 0
|
||||
ModRMByte = ModRMByte + 040h ; indicate disp8 byte offset
|
||||
ENDIF
|
||||
|
||||
IFNB <IndexReg>
|
||||
SibByte = 0
|
||||
|
|
@ -199,34 +207,36 @@ ENDIF
|
|||
SibByte = SibByte + (GprIndex_&BaseReg& AND 7)
|
||||
ENDIF
|
||||
|
||||
IFNB <IndexReg>
|
||||
db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte, SibByte
|
||||
ELSE
|
||||
db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte
|
||||
IFNB <IndexReg>
|
||||
db SibByte
|
||||
ENDIF
|
||||
IF ByteOffset NE 0
|
||||
db ByteOffset SHR 2
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
ENDM
|
||||
|
||||
VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
ENDM
|
||||
|
||||
VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
ENDM
|
||||
|
||||
VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
ENDM
|
||||
|
|
|
|||
1053
onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
Normal file
1053
onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
Normal file
File diff suppressed because it is too large
Load diff
130
onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm
Normal file
130
onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512BW.asm
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvx512BW.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX512BW instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8S8KernelAvx512Common.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate a single cell of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; AccumReg - Supplies the register to accumulate into.
|
||||
;
|
||||
; Mult1Reg - Supplies the first multiplication operand register.
|
||||
;
|
||||
; Mult2Reg - Supplies the second multiplication operand register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; zmm4 - Supplies a scratch register for intermediate results.
|
||||
;
|
||||
; zmm5 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
|
||||
;
|
||||
|
||||
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
vpmaddubsw zmm4,Mult1Reg,Mult2Reg
|
||||
vpmaddwd zmm4,zmm4,zmm5
|
||||
vpaddd AccumReg,AccumReg,zmm4
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
IF ColumnCount GE 48
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset]
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset]
|
||||
ELSEIF ColumnCount GE 32
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset]
|
||||
ELSE
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset]
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512BW
|
||||
|
||||
END
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvx512Common.inc
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module contains common kernel macros and structures for the quantized
|
||||
; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
|
||||
; AVX512VNNI kernels.
|
||||
;
|
||||
;--
|
||||
|
||||
INCLUDE QgemmU8X8KernelAvx512Common.inc
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times, advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy4Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ComputeBlockLoopExit
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
IF ((RowCount AND 1) EQ 0)
|
||||
sub rsi,4*4
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
ComputeBlockBy4Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0*64, 0
|
||||
ComputeBlock ColumnCount, RowCount, 1*64, 4
|
||||
ComputeBlock ColumnCount, RowCount, 2*64, 8
|
||||
ComputeBlock ColumnCount, RowCount, 3*64, 12
|
||||
add rcx,4*4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,4*64 ; advance matrix B
|
||||
sub rsi,4*4 ; decrement quads remaining
|
||||
jae ComputeBlockBy4Loop
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,4*4 ; correct for over-subtract above
|
||||
jz ComputeBlockLoopExit
|
||||
ENDIF
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 quad
|
||||
ENDIF
|
||||
add rdx,64 ; advance matrix B
|
||||
sub rsi,4 ; decrement quads remaining
|
||||
jnz ComputeBlockBy1Loop
|
||||
|
||||
ComputeBlockLoopExit:
|
||||
|
||||
ENDM
|
||||
102
onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm
Normal file
102
onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx512Vnni.asm
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8S8KernelAvx512Vnni.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements the kernels for the quantized integer matrix/matrix
|
||||
; multiply operation (QGEMM).
|
||||
;
|
||||
; This implementation uses AVX512VNNI instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8S8KernelAvx512Common.inc
|
||||
INCLUDE AssembleAvx512Vnni.inc
|
||||
.list
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulate each row of the output
|
||||
; block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
IF ColumnCount GE 48
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rdx+VectorOffset]
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+r14+VectorOffset]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14*2+VectorOffset]
|
||||
ELSEIF ColumnCount GE 32
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rdx+VectorOffset]
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+r14+VectorOffset]
|
||||
ELSE
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rdx+VectorOffset]
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
|
||||
|
||||
END
|
||||
|
|
@ -19,10 +19,9 @@
|
|||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
INCLUDE QgemmU8X8KernelAvx2Common.inc
|
||||
.list
|
||||
|
||||
EXTERN MlasMaskMoveAvx:NEAR
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8U8 CopyPackA routine.
|
||||
;
|
||||
|
|
@ -73,44 +72,6 @@ GemmU8U8CopyPackBFrame STRUCT
|
|||
|
||||
GemmU8U8CopyPackBFrame ENDS
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8U8 kernel.
|
||||
;
|
||||
|
||||
GemmU8U8KernelFrame STRUCT
|
||||
|
||||
SavedXmm6 OWORD ?
|
||||
SavedXmm7 OWORD ?
|
||||
SavedXmm8 OWORD ?
|
||||
SavedXmm9 OWORD ?
|
||||
SavedXmm10 OWORD ?
|
||||
SavedXmm11 OWORD ?
|
||||
SavedXmm12 OWORD ?
|
||||
SavedXmm13 OWORD ?
|
||||
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
SavedR14 QWORD ?
|
||||
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
PreviousP1Home QWORD ?
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
CountM QWORD ?
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumVector QWORD ?
|
||||
ColumnSumVector QWORD ?
|
||||
DepthValue QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8U8KernelFrame ENDS
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
|
|
@ -157,10 +118,10 @@ GemmU8U8KernelFrame ENDS
|
|||
push_reg r12
|
||||
push_reg r13
|
||||
alloc_stack (GemmU8U8CopyPackAFrame.SavedR13)
|
||||
save_xmm128_avx xmm6,GemmU8U8CopyPackAFrame.SavedXmm6
|
||||
save_xmm128_avx xmm7,GemmU8U8CopyPackAFrame.SavedXmm7
|
||||
save_xmm128_avx xmm8,GemmU8U8CopyPackAFrame.SavedXmm8
|
||||
save_xmm128_avx xmm9,GemmU8U8CopyPackAFrame.SavedXmm9
|
||||
save_xmm128 xmm6,GemmU8U8CopyPackAFrame.SavedXmm6
|
||||
save_xmm128 xmm7,GemmU8U8CopyPackAFrame.SavedXmm7
|
||||
save_xmm128 xmm8,GemmU8U8CopyPackAFrame.SavedXmm8
|
||||
save_xmm128 xmm9,GemmU8U8CopyPackAFrame.SavedXmm9
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
|
|
@ -195,13 +156,15 @@ GemmU8U8KernelFrame ENDS
|
|||
;
|
||||
; Process 4 rows of matrix A in a loop.
|
||||
;
|
||||
; For each row, zero extend the source bytes to 16-bits and write to the packed
|
||||
; buffer. The packed buffer has the same data ordering as the source bytes, but
|
||||
; the stride is CountK aligned up to an even number of 16-bit values.
|
||||
; Zero extend the source bytes to 16-bits and write to the packed buffer.
|
||||
;
|
||||
; The packed buffer has the same data ordering as the source bytes, but CountK
|
||||
; is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding
|
||||
; bytes are zero filled.
|
||||
;
|
||||
; These 16-bit values are also accumulated into an intermediate per-row
|
||||
; accumulator. CountK cannot be greater than 256 to avoid overflowing these
|
||||
; 16-bit accumulators.
|
||||
; accumulator. CountK cannot be greater than 128 to avoid overflowing these
|
||||
; signed 16-bit accumulators.
|
||||
;
|
||||
|
||||
sub r9,4
|
||||
|
|
@ -221,12 +184,12 @@ ProcessNextRowM4:
|
|||
jb ProcessRemainingColumnsM4
|
||||
|
||||
ProcessNextColumnLoopM4:
|
||||
lea rax,[rdx+r8*2] ; compute matrix A plus two rows
|
||||
lea rax,[rdx+r8*2] ; compute matrix A plus 2 rows
|
||||
vpmovzxbw ymm4,XMMWORD PTR [rdx]
|
||||
vpmovzxbw ymm5,XMMWORD PTR [rdx+r8]
|
||||
vpmovzxbw ymm6,XMMWORD PTR [rax]
|
||||
vpmovzxbw ymm7,XMMWORD PTR [rax+r8]
|
||||
lea rax,[rcx+r11*4] ; compute matrix D plus two rows
|
||||
lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows
|
||||
vmovdqu YMMWORD PTR [rcx],ymm4
|
||||
vmovdqu YMMWORD PTR [rcx+r11*2],ymm5
|
||||
vmovdqu YMMWORD PTR [rax],ymm6
|
||||
|
|
@ -252,7 +215,7 @@ ProcessRemainingColumnsM4:
|
|||
mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData
|
||||
test bl,8 ; (CountK & 8) != 0?
|
||||
jz CopyRemainingCountKLessThan8M4
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
mov rax,QWORD PTR [rdx+r8]
|
||||
|
|
@ -267,7 +230,7 @@ ProcessRemainingColumnsM4:
|
|||
CopyRemainingCountKLessThan8M4:
|
||||
test bl,4 ; (CountK & 4) != 0?
|
||||
jz CopyRemainingCountKLessThan4M4
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
mov eax,DWORD PTR [rdx+r8]
|
||||
|
|
@ -282,7 +245,7 @@ CopyRemainingCountKLessThan8M4:
|
|||
CopyRemainingCountKLessThan4M4:
|
||||
test bl,2 ; (CountK & 2) != 0?
|
||||
jz CopyRemainingCountKLessThan2M4
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
movzx eax,WORD PTR [rdx+r8]
|
||||
|
|
@ -297,7 +260,7 @@ CopyRemainingCountKLessThan4M4:
|
|||
CopyRemainingCountKLessThan2M4:
|
||||
test bl,1 ; (CountK & 1) != 0?
|
||||
jz ProcessPaddedMatrixADataM4
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus two rows
|
||||
lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
movzx eax,BYTE PTR [rdx+r8]
|
||||
|
|
@ -316,7 +279,7 @@ ProcessPaddedMatrixADataM4:
|
|||
vpmovzxbw ymm5,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+16]
|
||||
vpmovzxbw ymm6,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32]
|
||||
vpmovzxbw ymm7,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+48]
|
||||
lea rax,[rcx+r11*4] ; compute matrix D plus two rows
|
||||
lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows
|
||||
vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4
|
||||
vpmaskmovd YMMWORD PTR [rcx+r11*2],ymm9,ymm5
|
||||
vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6
|
||||
|
|
@ -334,22 +297,14 @@ ProcessPaddedMatrixADataM4:
|
|||
;
|
||||
|
||||
ReduceRowSumVectorM4:
|
||||
vpunpckldq ymm4,ymm0,ymm1 ; [A5 B5 A4 B4 A1 B1 A0 B0]
|
||||
vpunpckhdq ymm5,ymm0,ymm1 ; [A7 B7 A6 B6 A3 B3 A2 B2]
|
||||
vpunpckldq ymm6,ymm2,ymm3 ; [C5 D5 C4 D4 C1 D1 C0 D0]
|
||||
vpunpckhdq ymm7,ymm2,ymm3 ; [C7 D7 C6 D6 C3 D3 C2 D2]
|
||||
vpunpcklqdq ymm0,ymm4,ymm6 ; [A4 B4 C4 D4 A0 B0 C0 D0]
|
||||
vpunpckhqdq ymm1,ymm4,ymm6 ; [A5 B5 C5 D5 A1 B1 C1 D1]
|
||||
vpunpcklqdq ymm2,ymm5,ymm7 ; [A6 B6 C6 D6 A2 B2 C2 D2]
|
||||
vpunpckhqdq ymm3,ymm5,ymm7 ; [A7 B7 C7 D7 A3 B3 C3 D3]
|
||||
vpaddw ymm0,ymm0,ymm1 ; reduction
|
||||
vpaddw ymm0,ymm0,ymm2
|
||||
vpaddw ymm0,ymm0,ymm3
|
||||
vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0
|
||||
vphaddw ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2
|
||||
vphaddw ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0
|
||||
vextracti128 xmm1,ymm0,1 ; extract high pairs
|
||||
vpaddw xmm0,xmm0,xmm1 ; reduction
|
||||
vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce
|
||||
vpaddw xmm0,xmm0,xmm1 ; reduce low/high pairs
|
||||
vpmaddwd xmm0,xmm0,xmm8 ; multiply by offset and reduce 32-bit sum
|
||||
vmovdqu XMMWORD PTR [r12],xmm0
|
||||
add r12,4*4 ; advance row sum vector by 4 dwords
|
||||
add r12,4*4 ; advance row sum vector by 4 DWORDs
|
||||
sub r9,4 ; subtract rows remaining
|
||||
jae ProcessNextRowM4
|
||||
|
||||
|
|
@ -449,10 +404,10 @@ ReduceRowSumVectorM1:
|
|||
|
||||
ExitRoutine:
|
||||
vzeroupper
|
||||
vmovaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp]
|
||||
vmovaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp]
|
||||
vmovaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp]
|
||||
vmovaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp]
|
||||
movaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp]
|
||||
movaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp]
|
||||
movaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp]
|
||||
movaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp]
|
||||
add rsp,(GemmU8U8CopyPackAFrame.SavedR13)
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
|
@ -537,9 +492,9 @@ ProcessNextColumnN16:
|
|||
jb ProcessRemainingRowsN16
|
||||
|
||||
ProcessNextRowLoopN16:
|
||||
vmovdqu xmm2,XMMWORD PTR [rdx] ; load two rows
|
||||
vmovdqu xmm2,XMMWORD PTR [rdx] ; load 2 rows
|
||||
vmovdqu xmm3,XMMWORD PTR [rdx+r8]
|
||||
lea rdx,[rdx+r8*2] ; advance matrix B by two rows
|
||||
lea rdx,[rdx+r8*2] ; advance matrix B by 2 rows
|
||||
vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data
|
||||
vpunpckhbw xmm3,xmm2,xmm3
|
||||
vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows
|
||||
|
|
@ -549,7 +504,7 @@ ProcessNextRowLoopN16:
|
|||
add rcx,32 ; advance matrix D by 32 bytes
|
||||
vpaddw ymm0,ymm0,ymm4 ; accumulate per column
|
||||
vpaddw ymm1,ymm1,ymm3
|
||||
sub rbx,2 ; subtract columns remaining
|
||||
sub rbx,2 ; subtract rows remaining
|
||||
jae ProcessNextRowLoopN16
|
||||
|
||||
ProcessRemainingRowsN16:
|
||||
|
|
@ -569,7 +524,7 @@ ReduceColumnSumVectorN16:
|
|||
vpmaddwd ymm1,ymm1,ymm5 ; multiply by offset and reduce
|
||||
vmovdqu YMMWORD PTR [r11],ymm0
|
||||
vmovdqu YMMWORD PTR [r11+32],ymm1
|
||||
add r11,64 ; advance column sum vector by 16 dwords
|
||||
add r11,16*4 ; advance column sum vector by 16 DWORDs
|
||||
sub r9,16 ; subtract columns remaining
|
||||
jae ProcessNextColumnN16
|
||||
|
||||
|
|
@ -654,7 +609,7 @@ ProcessPaddedMatrixBDataK2:
|
|||
vpmovzxbw ymm3,xmm3
|
||||
vpaddw ymm0,ymm0,ymm4 ; accumulate per column
|
||||
vpaddw ymm1,ymm1,ymm3
|
||||
lea rsi,[rsi+r8*2] ; advance next matrix B by two rows
|
||||
lea rsi,[rsi+r8*2] ; advance next matrix B by 2 rows
|
||||
add rcx,32 ; advance matrix D by 32 bytes
|
||||
sub r10,2 ; subtract rows remaining
|
||||
jae ProcessNextRowLoopNUnaligned
|
||||
|
|
@ -739,13 +694,12 @@ ReduceColumnSumVectorNUnaligned:
|
|||
|
||||
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
IF ColumnCount EQ 16
|
||||
vpmaddwd ymm3,ymm2,ymm0
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd Vec1Reg,Vec1Reg,ymm3
|
||||
vpmaddwd ymm2,ymm2,ymm1
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm2
|
||||
ELSE
|
||||
vpmaddwd ymm3,ymm2,ymm0
|
||||
vpaddd Vec2Reg,Vec2Reg,ymm3
|
||||
ENDIF
|
||||
|
||||
|
|
@ -775,7 +729,7 @@ ENDIF
|
|||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
|
@ -786,15 +740,15 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
|||
EmitIfCountGE ColumnCount, 16, <vpmovzxbw ymm1,XMMWORD PTR [rdx+VectorOffset+16]>
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, ymm4, ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r10+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm2,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, ymm6, ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r10*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm2,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, ymm8, ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm2,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, ymm10, ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r10+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm2,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, ymm12, ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r10*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm2,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, ymm14, ymm15>
|
||||
|
||||
ENDM
|
||||
|
|
@ -802,8 +756,8 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
|||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times, advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
|
|
@ -813,281 +767,59 @@ ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
|||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the number of paired columns from matrix A and the number of
|
||||
; paired rows from matrix B to iterate over.
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockLoop
|
||||
LOCAL ComputeBlockBy2Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ComputeBlockLoopExit
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the sum of the global depth value constant,
|
||||
; the column sums, and the row sums.
|
||||
;
|
||||
|
||||
vpbroadcastd ymm1,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp]
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 ; advance ColumnSumVector by 16 columns
|
||||
ELSE
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r12]>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r12+4]>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r12+8]>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r12+12]>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r12+16]>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r12+20]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm9,ymm0>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm11,ymm0>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm13,ymm0>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm15,ymm0>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm1>
|
||||
|
||||
;
|
||||
; Iterate over PairedCountK elements from matrix A and matrix B.
|
||||
;
|
||||
; Unrolling the loop to do two iterations improves performance slightly at the
|
||||
; cost of larger code size. Balance this by only unrolling for the common case
|
||||
; of computing 16 columns for an even number of rows.
|
||||
;
|
||||
|
||||
mov rsi,r9 ; reload PairedCountK
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r10*2+r10]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0)
|
||||
sub rsi,2
|
||||
sub rsi,2*4
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
ComputeBlockLoop:
|
||||
ComputeBlockBy2Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
ComputeBlock ColumnCount, RowCount, 32, 4
|
||||
add rcx,2*4 ; advance matrix A by 2 pairs
|
||||
IF RowCount GT 3
|
||||
add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs
|
||||
ENDIF
|
||||
add rdx,2*32 ; advance matrix B by 64 columns
|
||||
sub rsi,2 ; subtract pairs remaining
|
||||
jae ComputeBlockLoop
|
||||
add rdx,2*32 ; advance matrix B
|
||||
sub rsi,2*4
|
||||
jae ComputeBlockBy2Loop
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,2 ; correct for over-subtract above
|
||||
add rsi,2*4 ; correct for over-subtract above
|
||||
jz ComputeBlockLoopExit
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rdx,32 ; advance matrix B by 32 columns
|
||||
add rdx,32 ; advance matrix B
|
||||
ELSE
|
||||
ComputeBlockLoop:
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 pair
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
|
||||
ENDIF
|
||||
add rdx,32
|
||||
dec rsi ; decrement pairs remaining
|
||||
jnz ComputeBlockLoop
|
||||
add rdx,32 ; advance matrix B
|
||||
sub rsi,4
|
||||
jnz ComputeBlockBy1Loop
|
||||
ENDIF
|
||||
|
||||
ComputeBlockLoopExit:
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Fallthrough - Supplies a non-blank value if the macro may fall through to
|
||||
; the ExitKernel label.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the number of paired columns from matrix A and the number of
|
||||
; paired rows from matrix B to iterate over.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
;
|
||||
; r14b - Supplies the zero mode flag.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount, Fallthrough
|
||||
|
||||
LOCAL ProcessNextColumnLoop16xN
|
||||
LOCAL SkipAccumulateOutput16xNBlock
|
||||
LOCAL OutputMasked16xNBlock
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL SkipAccumulateOutput8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked16xNBlock
|
||||
LOCAL OutputMasked8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked8xNBlock
|
||||
|
||||
cmp rbp,8
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop16xN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
sub rbp,16
|
||||
jb OutputMasked16xNBlock
|
||||
test r14b,r14b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
|
||||
|
||||
SkipAccumulateOutput16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
cmp rbp,8
|
||||
ja ProcessNextColumnLoop16xN
|
||||
test rbp,rbp
|
||||
jz ExitKernel
|
||||
|
||||
ProcessRemainingCountN:
|
||||
ProduceOutputBlock 8, RowCount
|
||||
cmp rbp,8
|
||||
jb OutputMasked8xNBlock
|
||||
test r14b,r14b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput8xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
|
||||
jmp ExitKernel
|
||||
|
||||
OutputMasked16xNBlock:
|
||||
test r14b,r14b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked16xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutputMasked16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
add r8,8*4 ; advance matrix C by 8 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
|
||||
ENDIF
|
||||
add rbp,8 ; correct for over-subtract above
|
||||
|
||||
OutputMasked8xNBlock:
|
||||
mov DWORD PTR GemmU8U8KernelFrame.CountN[rsp],ebp
|
||||
vpbroadcastd ymm0,DWORD PTR GemmU8U8KernelFrame.CountN[rsp]
|
||||
vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx]
|
||||
test r14b,r14b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked8xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
|
||||
|
||||
SkipAccumulateOutputMasked8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
|
||||
IFB <Fallthrough>
|
||||
jmp ExitKernel
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
|
|
@ -1108,8 +840,8 @@ ENDIF
|
|||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PairedCountK (r9) - Supplies the number of paired columns from matrix A and
|
||||
; the number of paired rows from matrix B to iterate over.
|
||||
; PairCountK (r9) - Supplies the number of paired columns from matrix A and the
|
||||
; number of paired rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
|
|
@ -1129,7 +861,7 @@ ENDIF
|
|||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrixA multplied by the zero point offset of matrix B. This value is
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
|
|
@ -1149,30 +881,29 @@ ENDIF
|
|||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
push_reg r14
|
||||
alloc_stack (GemmU8U8KernelFrame.SavedR14)
|
||||
save_xmm128_avx xmm6,GemmU8U8KernelFrame.SavedXmm6
|
||||
save_xmm128_avx xmm7,GemmU8U8KernelFrame.SavedXmm7
|
||||
save_xmm128_avx xmm8,GemmU8U8KernelFrame.SavedXmm8
|
||||
save_xmm128_avx xmm9,GemmU8U8KernelFrame.SavedXmm9
|
||||
save_xmm128_avx xmm10,GemmU8U8KernelFrame.SavedXmm10
|
||||
save_xmm128_avx xmm11,GemmU8U8KernelFrame.SavedXmm11
|
||||
save_xmm128_avx xmm12,GemmU8U8KernelFrame.SavedXmm12
|
||||
save_xmm128_avx xmm13,GemmU8U8KernelFrame.SavedXmm13
|
||||
save_xmm128_avx xmm14,GemmU8U8KernelFrame.SavedXmm14
|
||||
save_xmm128_avx xmm15,GemmU8U8KernelFrame.SavedXmm15
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR13)
|
||||
save_xmm128 xmm6,GemmU8X8KernelFrame.SavedXmm6
|
||||
save_xmm128 xmm7,GemmU8X8KernelFrame.SavedXmm7
|
||||
save_xmm128 xmm8,GemmU8X8KernelFrame.SavedXmm8
|
||||
save_xmm128 xmm9,GemmU8X8KernelFrame.SavedXmm9
|
||||
save_xmm128 xmm10,GemmU8X8KernelFrame.SavedXmm10
|
||||
save_xmm128 xmm11,GemmU8X8KernelFrame.SavedXmm11
|
||||
save_xmm128 xmm12,GemmU8X8KernelFrame.SavedXmm12
|
||||
save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
mov rdi,rcx
|
||||
mov rbp,GemmU8U8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8U8KernelFrame.ldc[rsp]
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
lea r10,[r9*4]
|
||||
mov r11,GemmU8U8KernelFrame.CountM[rsp]
|
||||
mov r12,GemmU8U8KernelFrame.RowSumVector[rsp]
|
||||
mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp]
|
||||
movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp]
|
||||
shl r9,2 ; convert to row length
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
|
||||
mov r11,GemmU8X8KernelFrame.CountM[rsp]
|
||||
mov r12,GemmU8X8KernelFrame.RowSumVector[rsp]
|
||||
mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp]
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
|
|
@ -1204,21 +935,20 @@ ProcessCountM6:
|
|||
ExitKernel:
|
||||
mov eax,r11d
|
||||
vzeroupper
|
||||
vmovaps xmm6,GemmU8U8KernelFrame.SavedXmm6[rsp]
|
||||
vmovaps xmm7,GemmU8U8KernelFrame.SavedXmm7[rsp]
|
||||
vmovaps xmm8,GemmU8U8KernelFrame.SavedXmm8[rsp]
|
||||
vmovaps xmm9,GemmU8U8KernelFrame.SavedXmm9[rsp]
|
||||
vmovaps xmm10,GemmU8U8KernelFrame.SavedXmm10[rsp]
|
||||
vmovaps xmm11,GemmU8U8KernelFrame.SavedXmm11[rsp]
|
||||
vmovaps xmm12,GemmU8U8KernelFrame.SavedXmm12[rsp]
|
||||
vmovaps xmm13,GemmU8U8KernelFrame.SavedXmm13[rsp]
|
||||
vmovaps xmm14,GemmU8U8KernelFrame.SavedXmm14[rsp]
|
||||
vmovaps xmm15,GemmU8U8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8U8KernelFrame.SavedR14)
|
||||
movaps xmm6,GemmU8X8KernelFrame.SavedXmm6[rsp]
|
||||
movaps xmm7,GemmU8X8KernelFrame.SavedXmm7[rsp]
|
||||
movaps xmm8,GemmU8X8KernelFrame.SavedXmm8[rsp]
|
||||
movaps xmm9,GemmU8X8KernelFrame.SavedXmm9[rsp]
|
||||
movaps xmm10,GemmU8X8KernelFrame.SavedXmm10[rsp]
|
||||
movaps xmm11,GemmU8X8KernelFrame.SavedXmm11[rsp]
|
||||
movaps xmm12,GemmU8X8KernelFrame.SavedXmm12[rsp]
|
||||
movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp]
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR13)
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rdi
|
||||
|
|
|
|||
|
|
@ -25,39 +25,26 @@ INCLUDE QgemmU8U8KernelAvx512Common.inc
|
|||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to multiply and accumulator a single row of the
|
||||
; This macro generates code to multiply and accumulate a single cell of the
|
||||
; output block.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
; AccumReg - Supplies the register to accumulate into.
|
||||
;
|
||||
; Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
; is 32).
|
||||
; Mult1Reg - Supplies the first multiplication operand register.
|
||||
;
|
||||
; Vec2Reg - Supplies the low block accumulator register.
|
||||
; Mult2Reg - Supplies the second multiplication operand register.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; zmm28 - Supplies the first vector loaded from matrix B.
|
||||
;
|
||||
; zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
; is 32).
|
||||
;
|
||||
; zmm30 - Supplies the broadcast value loaded from matrix A.
|
||||
; zmm4 - Supplies a scratch register for intermediate results.
|
||||
;
|
||||
|
||||
MultiplyAccumulateRow MACRO ColumnCount, Vec1Reg, Vec2Reg
|
||||
MultiplyAccumulateCell MACRO AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
IF ColumnCount EQ 32
|
||||
vpmaddwd zmm31,zmm30,zmm28
|
||||
vpaddd Vec1Reg,Vec1Reg,zmm31
|
||||
vpmaddwd zmm30,zmm30,zmm29
|
||||
vpaddd Vec2Reg,Vec2Reg,zmm30
|
||||
ELSE
|
||||
vpmaddwd zmm31,zmm30,zmm28
|
||||
vpaddd Vec2Reg,Vec2Reg,zmm31
|
||||
ENDIF
|
||||
vpmaddwd zmm4,Mult1Reg,Mult2Reg
|
||||
vpaddd AccumReg,AccumReg,zmm4
|
||||
|
||||
ENDM
|
||||
|
||||
|
|
@ -73,6 +60,10 @@ ENDIF
|
|||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
|
@ -81,27 +72,49 @@ ENDIF
|
|||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; zmm16-zmm27 - Supplies the block accumulators.
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw zmm28,YMMWORD PTR [rdx]
|
||||
EmitIfCountGE ColumnCount, 32, <vpmovzxbw zmm29,YMMWORD PTR [rdx+r10*8]>
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm30,DWORD PTR [rcx]>
|
||||
EmitIfCountGE RowCount, 1, <MultiplyAccumulateRow ColumnCount, zmm16, zmm17>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm30,DWORD PTR [rcx+r10]>
|
||||
EmitIfCountGE RowCount, 2, <MultiplyAccumulateRow ColumnCount, zmm18, zmm19>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm30,DWORD PTR [rcx+r10*2]>
|
||||
EmitIfCountGE RowCount, 3, <MultiplyAccumulateRow ColumnCount, zmm20, zmm21>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm30,DWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 4, <MultiplyAccumulateRow ColumnCount, zmm22, zmm23>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm30,DWORD PTR [rbx+r10]>
|
||||
EmitIfCountGE RowCount, 5, <MultiplyAccumulateRow ColumnCount, zmm24, zmm25>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]>
|
||||
EmitIfCountGE RowCount, 6, <MultiplyAccumulateRow ColumnCount, zmm26, zmm27>
|
||||
IF ColumnCount GE 48
|
||||
vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset]
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset]
|
||||
ELSEIF ColumnCount GE 32
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset]
|
||||
ELSE
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset]
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <MultiplyAccumulateCell zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <MultiplyAccumulateCell zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <MultiplyAccumulateCell zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <MultiplyAccumulateCell zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <MultiplyAccumulateCell zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <MultiplyAccumulateCell zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <MultiplyAccumulateCell zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <MultiplyAccumulateCell zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <MultiplyAccumulateCell zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <MultiplyAccumulateCell zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <MultiplyAccumulateCell zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <MultiplyAccumulateCell zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <MultiplyAccumulateCell zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <MultiplyAccumulateCell zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <MultiplyAccumulateCell zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <MultiplyAccumulateCell zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <MultiplyAccumulateCell zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <MultiplyAccumulateCell zmm19,zmm3,zmm2>
|
||||
|
||||
ENDM
|
||||
|
||||
|
|
@ -109,6 +122,6 @@ ComputeBlock MACRO ColumnCount, RowCount
|
|||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8U8KernelAvx512Function Avx512BW
|
||||
GemmU8X8KernelAvx512Function U8U8, Avx512BW
|
||||
|
||||
END
|
||||
|
|
|
|||
|
|
@ -16,39 +16,13 @@
|
|||
;
|
||||
;--
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8U8 kernel.
|
||||
;
|
||||
|
||||
GemmU8U8KernelFrame STRUCT
|
||||
|
||||
SavedR14 QWORD ?
|
||||
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
PreviousP1Home QWORD ?
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
CountM QWORD ?
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumVector QWORD ?
|
||||
ColumnSumVector QWORD ?
|
||||
DepthValue QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8U8KernelFrame ENDS
|
||||
INCLUDE QgemmU8X8KernelAvx512Common.inc
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
; This macro generates code to execute the block compute macro multiple
|
||||
; times while advancing the matrix A and matrix B data pointers.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
|
|
@ -58,328 +32,33 @@ GemmU8U8KernelFrame ENDS
|
|||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the number of paired columns from matrix A and the number of
|
||||
; paired rows from matrix B to iterate over.
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
ComputeBlockLoop MACRO ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockLoop
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the sum of the global depth value constant,
|
||||
; the column sums, and the row sums.
|
||||
;
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
vpbroadcastd zmm31,DWORD PTR GemmU8U8KernelFrame.DepthValue[rsp]
|
||||
IF ColumnCount EQ 32
|
||||
vpaddd zmm30,zmm31,ZMMWORD PTR [r13]
|
||||
vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64]
|
||||
add r13,32*4 ; advance ColumnSumVector by 32 columns
|
||||
ELSE
|
||||
vpaddd zmm31,zmm31,ZMMWORD PTR [r13]
|
||||
ENDIF
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm16,zmm30,DWORD BCST [r12]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm17,zmm31,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm18,zmm30,DWORD BCST [r12+4]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm19,zmm31,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm20,zmm30,DWORD BCST [r12+8]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm21,zmm31,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm22,zmm30,DWORD BCST [r12+12]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm31,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm30,DWORD BCST [r12+16]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm25,zmm31,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm26,zmm30,DWORD BCST [r12+20]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm27,zmm31,DWORD BCST [r12+20]>
|
||||
|
||||
;
|
||||
; Iterate over PairedCountK elements from matrix A and matrix B.
|
||||
;
|
||||
|
||||
mov rsi,r9 ; reload PairedCountK
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r10*2+r10]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
|
||||
ComputeBlockLoop:
|
||||
ComputeBlock ColumnCount, RowCount
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlock ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 pair
|
||||
IF RowCount GT 3
|
||||
add rbx,4 ; advance matrix A plus 3 rows by 1 pair
|
||||
ENDIF
|
||||
add rdx,32
|
||||
dec rsi ; decrement pairs remaining
|
||||
jnz ComputeBlockLoop
|
||||
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the number of paired columns from matrix A and the number of
|
||||
; paired rows from matrix B to iterate over.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
;
|
||||
; r14b - Supplies the zero mode flag.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount
|
||||
;
; Emits the column loop of the AVX512 kernel for a fixed number of rows
; (RowCount; the EmitIfCountGE lines below cover rows 1 through 6).
;
; rbp holds the number of columns still to be produced. While more than 16
; remain, a 32-column block is produced; the last 16 or fewer columns are
; written through the k1 write mask. r14b is the ZeroMode flag: when it is
; zero, the existing contents of matrix C are added to the accumulators before
; they are stored.
;
; ExitKernel is not declared LOCAL, so the code that expands this macro must
; define that label.
;
|
||||
LOCAL ProcessNextColumnLoop32xN
|
||||
LOCAL SkipAccumulateOutput32xNBlock
|
||||
LOCAL Output16xNBlock
|
||||
LOCAL Output16xNBlockWithMask
|
||||
LOCAL SkipAccumulateOutput16xNBlockWithMask
|
||||
LOCAL ProcessRemainingCountN
|
||||
|
||||
; 16 or fewer columns in total: skip the 32-column path entirely.
cmp rbp,16
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop32xN:
|
||||
ProduceOutputBlock 32, RowCount
|
||||
lea rdx,[rdx+r10*8] ; advance matrix B by 8*PairedCountK
|
||||
test r14b,r14b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput32xNBlock
|
||||
; Accumulate mode: add the current C values for the first 16 columns. Rows 1-3
; are addressed from r8, rows 4-6 from rbx; rax is used as the row stride.
EmitIfCountGE RowCount, 1, <vpaddd zmm16,zmm16,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm18,zmm18,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm22,zmm22,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm26,zmm26,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput32xNBlock:
|
||||
; Store the first 16 columns of the 32-column block (even-numbered
; accumulators zmm16..zmm26).
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm16>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm18>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm20>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm22>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm26>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
; Account for the 16 columns just stored, then fall through so the second 16
; columns (odd-numbered accumulators zmm17..zmm27) go out via the 16xN path.
sub rbp,16
|
||||
|
||||
Output16xNBlock:
|
||||
; If at least 16 columns remained there is no borrow and the current k1 is used
; as-is. Otherwise rbp+16 is the number of columns left and k1 is rebuilt as
; (1 << remaining) - 1 so only those lanes are touched.
sub rbp,16
|
||||
jae Output16xNBlockWithMask
|
||||
lea ecx,[ebp+16] ; correct for over-subtract above
|
||||
mov esi,1
|
||||
shl esi,cl
|
||||
dec esi
|
||||
kmovw k1,esi ; update mask for remaining columns
|
||||
xor ebp,ebp ; no more columns remaining
|
||||
|
||||
Output16xNBlockWithMask:
|
||||
test r14b,r14b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlockWithMask
|
||||
; Accumulate mode: masked add of the current C values into the odd-numbered
; accumulators.
EmitIfCountGE RowCount, 1, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm21{k1},zmm21,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm23{k1},zmm23,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm25{k1},zmm25,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm27{k1},zmm27,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput16xNBlockWithMask:
|
||||
; Masked store of the (up to) 16 columns selected by k1.
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm17>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm19>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm21>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm23>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm25>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm27>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
; More than 16 columns left: another 32-column block. 1..16 left: fall into the
; 16-column tail below. None left: done.
cmp rbp,16
|
||||
ja ProcessNextColumnLoop32xN
|
||||
test rbp,rbp
|
||||
jz ExitKernel
|
||||
|
||||
ProcessRemainingCountN:
|
||||
; Produce a single 16-column block and reuse the masked output path above.
ProduceOutputBlock 16, RowCount
|
||||
jmp Output16xNBlock
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates the common AVX512 code for the inner kernel to compute
|
||||
; matrix multiplication.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string for function tags.
|
||||
;
|
||||
|
||||
GemmU8U8KernelAvx512Function MACRO Isa
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8U8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8U8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; PairedCountK (r9) - Supplies the number of paired columns from matrix A and
|
||||
; the number of paired rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumVector - Supplies the sum of each row from matrix A multiplied by the
|
||||
; zero point offset of matrix B. These values are accumulated into every
|
||||
; row of matrix C.
|
||||
;
|
||||
; ColumnSumVector - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasGemmU8U8Kernel&Isa&, _TEXT
|
||||
|
||||
rex_push_reg rbp
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
push_reg r14
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
;
; Load the arguments into the registers used by the ProcessCountM expansions:
;   rdi = copy of the matrix A pointer     rbp = CountN
;   rax = ldc * 4 (row stride of C)        r10 = PairedCountK * 4
;   r11 = CountM                           r12 = RowSumVector
;   r13 = ColumnSumVector                  r14 = ZeroMode (byte, zero-extended)
;   k1  = all ones (kmovw takes the low 16 bits of esi)
;
mov rdi,rcx
|
||||
mov rbp,GemmU8U8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8U8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
lea r10,[r9*4]
|
||||
mov r11,GemmU8U8KernelFrame.CountM[rsp]
|
||||
mov r12,GemmU8U8KernelFrame.RowSumVector[rsp]
|
||||
mov r13,GemmU8U8KernelFrame.ColumnSumVector[rsp]
|
||||
movzx r14,BYTE PTR GemmU8U8KernelFrame.ZeroMode[rsp]
|
||||
mov esi,-1
|
||||
kmovw k1,esi ; update mask to write all columns
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
|
||||
|
||||
; Dispatch on CountM: more than 5 -> 6-row path, exactly 5/4/3/1 -> the
; matching path, anything else falls through to the 2-row path.
; NOTE(review): CountM == 0 would also reach ProcessCountM2 - presumably
; callers never pass 0; confirm.
cmp r11,5
|
||||
ja ProcessCountM6
|
||||
je ProcessCountM5
|
||||
cmp r11,3
|
||||
ja ProcessCountM4
|
||||
je ProcessCountM3
|
||||
cmp r11,1
|
||||
je ProcessCountM1
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
ProcessCountM6:
|
||||
mov r11d,6 ; return 6 rows handled
|
||||
ProcessCountM 6
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
|
||||
ExitKernel:
|
||||
; Return value: r11 holds CountM as loaded above, or 6 on the 6-row path
; (assumes the ProcessCountM expansions leave r11 untouched).
mov eax,r11d
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
; The odd row counts are expanded after the epilogue; they are reached only by
; the conditional jumps in the dispatch above.
ProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
NESTED_END MlasGemmU8U8Kernel&Isa&, _TEXT
|
||||
; NOTE(review): the three instructions below sit after NESTED_END and branch to
; ComputeBlockBy1Loop, a label that is not defined anywhere in this macro. They
; look like they belong to a different macro - confirm before assembling.
add rdx,32 ; advance matrix B
|
||||
sub rsi,4
|
||||
jnz ComputeBlockBy1Loop
|
||||
|
||||
ENDM
|
||||
|
|
|
|||
|
|
@ -35,6 +35,10 @@ INCLUDE AssembleAvx512Vnni.inc
|
|||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
;
|
||||
; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
|
@ -43,41 +47,56 @@ INCLUDE AssembleAvx512Vnni.inc
|
|||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r10 - Supplies the length in bytes of a row from matrix A.
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; zmm16-zmm27 - Supplies the block accumulators.
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
; zmm14-zmm31 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ComputeBlock MACRO ColumnCount, RowCount
|
||||
ComputeBlock MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw zmm28,YMMWORD PTR [rdx]
|
||||
IF ColumnCount EQ 32
|
||||
vpmovzxbw zmm29,YMMWORD PTR [rdx+r10*8]
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm30,DWORD PTR [rcx]>
|
||||
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmZmm zmm16,zmm28,zmm30>
|
||||
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmZmm zmm17,zmm29,zmm30>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm30,DWORD PTR [rcx+r10]>
|
||||
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmZmm zmm18,zmm28,zmm30>
|
||||
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmZmm zmm19,zmm29,zmm30>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm30,DWORD PTR [rcx+r10*2]>
|
||||
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmZmm zmm20,zmm28,zmm30>
|
||||
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmZmm zmm21,zmm29,zmm30>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm30,DWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmZmm zmm22,zmm28,zmm30>
|
||||
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmZmm zmm23,zmm29,zmm30>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm30,DWORD PTR [rbx+r10]>
|
||||
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmZmm zmm24,zmm28,zmm30>
|
||||
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmZmm zmm25,zmm29,zmm30>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]>
|
||||
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmZmm zmm26,zmm28,zmm30>
|
||||
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmZmm zmm27,zmm29,zmm30>
|
||||
IF ColumnCount GE 32
|
||||
IF ColumnCount GE 48
|
||||
vpmovzxbw zmm0,YMMWORD PTR [rdx+VectorOffset]
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rdx+r14+VectorOffset]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14*2+VectorOffset]
|
||||
ELSE
|
||||
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmBroadcast zmm17,zmm28,rcx>
|
||||
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmBroadcast zmm19,zmm28,rcx,r10,1>
|
||||
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmBroadcast zmm21,zmm28,rcx,r10,2>
|
||||
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmBroadcast zmm23,zmm28,rbx>
|
||||
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmBroadcast zmm25,zmm28,rbx,r10,1>
|
||||
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmBroadcast zmm27,zmm28,rbx,r10,2>
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rdx+VectorOffset]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+r14+VectorOffset]
|
||||
ENDIF
|
||||
EmitIfCountGE RowCount, 1, <vpbroadcastd zmm3,DWORD PTR [rcx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm26,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm20,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm14,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd zmm3,DWORD PTR [rcx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm27,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm21,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm15,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd zmm3,DWORD PTR [rcx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm28,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm22,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm16,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd zmm3,DWORD PTR [rbx+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm29,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm23,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm17,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd zmm3,DWORD PTR [rbx+r9+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm30,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm24,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm18,zmm3,zmm2>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd zmm3,DWORD PTR [rbx+r9*2+BroadcastOffset]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <VpdpwssdZmmZmmZmm zmm31,zmm3,zmm0>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <VpdpwssdZmmZmmZmm zmm25,zmm3,zmm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <VpdpwssdZmmZmmZmm zmm19,zmm3,zmm2>
|
||||
ELSE
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rdx+VectorOffset]
|
||||
EmitIfCountGE RowCount, 1, <VpdpwssdZmmZmmBroadcast zmm14,zmm2,rcx,BroadcastOffset>
|
||||
EmitIfCountGE RowCount, 2, <VpdpwssdZmmZmmBroadcast zmm15,zmm2,rcx,BroadcastOffset,r9,1>
|
||||
EmitIfCountGE RowCount, 3, <VpdpwssdZmmZmmBroadcast zmm16,zmm2,rcx,BroadcastOffset,r9,2>
|
||||
EmitIfCountGE RowCount, 4, <VpdpwssdZmmZmmBroadcast zmm17,zmm2,rbx,BroadcastOffset>
|
||||
EmitIfCountGE RowCount, 5, <VpdpwssdZmmZmmBroadcast zmm18,zmm2,rbx,BroadcastOffset,r9,1>
|
||||
EmitIfCountGE RowCount, 6, <VpdpwssdZmmZmmBroadcast zmm19,zmm2,rbx,BroadcastOffset,r9,2>
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
|
@ -86,6 +105,6 @@ ENDIF
|
|||
; Generate the GEMM kernel.
|
||||
;
|
||||
|
||||
GemmU8U8KernelAvx512Function Avx512Vnni
|
||||
GemmU8X8KernelAvx512Function U8U8, Avx512Vnni
|
||||
|
||||
END
|
||||
|
|
|
|||
302
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc
Normal file
302
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2Common.inc
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8X8KernelAvx2Common.inc
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module contains common kernel macros and structures for the quantized
|
||||
; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
|
||||
;
|
||||
;--
|
||||
|
||||
EXTERN MlasMaskMoveAvx:NEAR
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
;
|
||||
|
||||
GemmU8X8KernelFrame STRUCT
|
||||
|
||||
; The kernel macros address fields of this structure relative to rsp
; (GemmU8X8KernelFrame.<field>[rsp]), so the fields are listed from the lowest
; stack address upward.
;
; Save slots for xmm6-xmm15.
SavedXmm6 OWORD ?
|
||||
SavedXmm7 OWORD ?
|
||||
SavedXmm8 OWORD ?
|
||||
SavedXmm9 OWORD ?
|
||||
SavedXmm10 OWORD ?
|
||||
SavedXmm11 OWORD ?
|
||||
SavedXmm12 OWORD ?
|
||||
SavedXmm13 OWORD ?
|
||||
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
; NOTE(review): Padding presumably keeps the XMM save area 16-byte aligned -
; confirm against the kernel prologue.
Padding QWORD ?
|
||||
; Save slots for the general-purpose registers, followed by the return address.
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
; Home space for the four register parameters.
PreviousP1Home QWORD ?
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
; Remaining kernel arguments. The CountN slot is also reused as scratch by
; ProcessCountM when it builds the column mask.
CountM QWORD ?
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumVector QWORD ?
|
||||
ColumnSumVector QWORD ?
|
||||
DepthValue QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8X8KernelFrame ENDS
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
;
|
||||
; ymm4-ymm15 - Supplies the block accumulators.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the sum of the global depth value constant,
|
||||
; the column sums, and the row sums.
|
||||
;
|
||||
|
||||
; ymm1 = DepthValue broadcast to all lanes. For a 16-column block ymm0 then
; holds depth + column sums 0-7 and ymm1 depth + column sums 8-15, and r13 is
; advanced. Otherwise only ymm1 (column sums 0-7) is formed and r13 is left
; unchanged by this macro.
vpbroadcastd ymm1,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
|
||||
IF ColumnCount EQ 16
|
||||
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 ; advance ColumnSumVector by 16 columns
|
||||
ELSE
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
|
||||
ENDIF
|
||||
; Broadcast the row sum of each row (consecutive DWORDs at r12) into that row's
; odd-numbered accumulator.
EmitIfCountGE RowCount, 1, <vpbroadcastd ymm5,DWORD PTR [r12]>
|
||||
EmitIfCountGE RowCount, 2, <vpbroadcastd ymm7,DWORD PTR [r12+4]>
|
||||
EmitIfCountGE RowCount, 3, <vpbroadcastd ymm9,DWORD PTR [r12+8]>
|
||||
EmitIfCountGE RowCount, 4, <vpbroadcastd ymm11,DWORD PTR [r12+12]>
|
||||
EmitIfCountGE RowCount, 5, <vpbroadcastd ymm13,DWORD PTR [r12+16]>
|
||||
EmitIfCountGE RowCount, 6, <vpbroadcastd ymm15,DWORD PTR [r12+20]>
|
||||
; Per row: the even accumulator (16-column blocks only) becomes row sum + ymm0,
; and the odd accumulator becomes row sum + ymm1. The even one must be formed
; first because the odd register still holds the bare row sum at that point.
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm1>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd ymm8,ymm9,ymm0>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm1>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd ymm10,ymm11,ymm0>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm1>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd ymm12,ymm13,ymm0>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm1>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd ymm14,ymm15,ymm0>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm1>
|
||||
|
||||
;
|
||||
; Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
;
|
||||
|
||||
; rbx serves two purposes: during ComputeBlockLoop (defined elsewhere) it is
; rcx + 3*r9, and afterwards it is repointed to r8 + 3*rax for the output code.
IF RowCount GT 3
|
||||
lea rbx,[r9*2+r9]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
ComputeBlockLoop ColumnCount, RowCount
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Fallthrough - Supplies a non-blank value if the macro may fall through to
|
||||
; the ExitKernel label.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10b - Supplies the zero mode flag.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount, Fallthrough
|
||||
|
||||
LOCAL ProcessNextColumnLoop16xN
|
||||
LOCAL SkipAccumulateOutput16xNBlock
|
||||
LOCAL OutputMasked16xNBlock
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL SkipAccumulateOutput8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked16xNBlock
|
||||
LOCAL OutputMasked8xNBlock
|
||||
LOCAL SkipAccumulateOutputMasked8xNBlock
|
||||
|
||||
; rbp is the number of columns still to produce. With 8 or fewer columns only
; the 8-column tail is needed.
cmp rbp,8
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop16xN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
; A borrow here means fewer than 16 columns remained (9-15, since this path is
; only entered with more than 8), so the second half needs a masked store.
sub rbp,16
|
||||
jb OutputMasked16xNBlock
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlock
|
||||
; Accumulate mode: add the existing 16 columns of C. Even accumulators cover
; columns 0-7, odd accumulators columns 8-15 (+32 bytes).
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8+32]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax+32]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2+32]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]>
|
||||
|
||||
SkipAccumulateOutput16xNBlock:
|
||||
; Store the full 16-column block.
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8+32],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax+32],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2+32],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx+32],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax+32],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
; More than 8 columns left: another 16-column block. 1-8 left: fall into the
; 8-column tail. None left: done.
cmp rbp,8
|
||||
ja ProcessNextColumnLoop16xN
|
||||
test rbp,rbp
|
||||
jz ExitKernel
|
||||
|
||||
ProcessRemainingCountN:
|
||||
; 8-column block: only the odd accumulators are produced. Exactly 8 columns
; take the unmasked path below; fewer than 8 go to the masked store.
ProduceOutputBlock 8, RowCount
|
||||
cmp rbp,8
|
||||
jb OutputMasked8xNBlock
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput8xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput8xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm15>
|
||||
jmp ExitKernel
|
||||
|
||||
OutputMasked16xNBlock:
|
||||
; Partial 16-column block: columns 0-7 (even accumulators) are complete and are
; written unmasked here; the remaining columns go through the masked path.
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked16xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm4,ymm4,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm6,ymm6,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm8,ymm8,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm10,ymm10,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutputMasked16xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu YMMWORD PTR [r8],ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu YMMWORD PTR [r8+rax],ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu YMMWORD PTR [r8+rax*2],ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu YMMWORD PTR [rbx],ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu YMMWORD PTR [rbx+rax],ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu YMMWORD PTR [rbx+rax*2],ymm14>
|
||||
add r8,8*4 ; advance matrix C by 8 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns
|
||||
ENDIF
|
||||
add rbp,8 ; correct for over-subtract above
|
||||
|
||||
OutputMasked8xNBlock:
|
||||
; Build the lane mask in ymm0: the remaining column count is spilled to the
; CountN stack slot (reused as scratch), broadcast, and compared against the
; external MlasMaskMoveAvx table, leaving all-ones in every lane whose table
; entry is less than the count.
; NOTE(review): presumably the table holds the lane indices 0..7 - confirm.
mov DWORD PTR GemmU8X8KernelFrame.CountN[rsp],ebp
|
||||
vpbroadcastd ymm0,DWORD PTR GemmU8X8KernelFrame.CountN[rsp]
|
||||
vpcmpgtd ymm0,ymm0,YMMWORD PTR [MlasMaskMoveAvx]
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutputMasked8xNBlock
|
||||
; Accumulate mode: masked-load the existing C values into the even registers
; (used here as temporaries) and add them to the odd accumulators.
EmitIfCountGE RowCount, 1, <vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd ymm6,ymm0,YMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]>
|
||||
EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm4>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd ymm7,ymm7,ymm6>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd ymm9,ymm9,ymm8>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd ymm11,ymm11,ymm10>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd ymm13,ymm13,ymm12>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd ymm15,ymm15,ymm14>
|
||||
|
||||
SkipAccumulateOutputMasked8xNBlock:
|
||||
; Masked store of the remaining columns.
EmitIfCountGE RowCount, 1, <vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5>
|
||||
EmitIfCountGE RowCount, 2, <vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm7>
|
||||
EmitIfCountGE RowCount, 3, <vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm9>
|
||||
EmitIfCountGE RowCount, 4, <vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11>
|
||||
EmitIfCountGE RowCount, 5, <vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13>
|
||||
EmitIfCountGE RowCount, 6, <vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15>
|
||||
; When the Fallthrough argument is blank an explicit jump to ExitKernel is
; emitted; otherwise execution continues with whatever follows the expansion.
IFB <Fallthrough>
|
||||
jmp ExitKernel
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
438
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc
Normal file
438
onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Common.inc
Normal file
|
|
@ -0,0 +1,438 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; QgemmU8X8KernelAvx512Common.inc
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module contains common kernel macros and structures for the quantized
|
||||
; integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
|
||||
; AVX512VNNI kernels.
|
||||
;
|
||||
;--
|
||||
|
||||
;
|
||||
; Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
;
|
||||
|
||||
GemmU8X8KernelFrame STRUCT
|
||||
|
||||
; The kernel macros address fields of this structure relative to rsp
; (GemmU8X8KernelFrame.<field>[rsp]), so the fields are listed from the lowest
; stack address upward.
;
; Save slots for xmm14 and xmm15.
; NOTE(review): presumably only these two are saved because zmm14/zmm15 are the
; lowest-numbered accumulators used by these kernels - confirm against the
; prologue.
SavedXmm14 OWORD ?
|
||||
SavedXmm15 OWORD ?
|
||||
; Save slots for the general-purpose registers, followed by the return address.
SavedR14 QWORD ?
|
||||
SavedR13 QWORD ?
|
||||
SavedR12 QWORD ?
|
||||
SavedRdi QWORD ?
|
||||
SavedRsi QWORD ?
|
||||
SavedRbx QWORD ?
|
||||
SavedRbp QWORD ?
|
||||
ReturnAddress QWORD ?
|
||||
; Home space for the four register parameters.
PreviousP1Home QWORD ?
|
||||
PreviousP2Home QWORD ?
|
||||
PreviousP3Home QWORD ?
|
||||
PreviousP4Home QWORD ?
|
||||
; Remaining kernel arguments.
CountM QWORD ?
|
||||
CountN QWORD ?
|
||||
ldc QWORD ?
|
||||
RowSumVector QWORD ?
|
||||
ColumnSumVector QWORD ?
|
||||
DepthValue QWORD ?
|
||||
ZeroMode QWORD ?
|
||||
|
||||
GemmU8X8KernelFrame ENDS
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to produce an output block for a set of columns
|
||||
; and rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; ColumnCount - Supplies the number of columns to produce.
|
||||
;
|
||||
; RowCount - Supplies the number of rows to produce.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address into the matrix A data.
|
||||
;
|
||||
; rdx - Supplies the address into the matrix B data.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
;
|
||||
|
||||
ProduceOutputBlock MACRO ColumnCount, RowCount
|
||||
|
||||
;
|
||||
; Initialize the accumulators with the sum of the global depth value constant,
|
||||
; the column sums, and the row sums.
|
||||
;
|
||||
|
||||
; zmm3 = DepthValue broadcast to all lanes. The depth + column-sum vectors are
; assigned so that zmm0 always holds the LAST 16-column group of the block:
;   48 columns: zmm2 = cols 0-15, zmm1 = cols 16-31, zmm0 = cols 32-47
;   32 columns: zmm1 = cols 0-15, zmm0 = cols 16-31
;   otherwise:  zmm0 = cols 0-15 (r13 is not advanced by this macro)
vpbroadcastd zmm3,DWORD PTR GemmU8X8KernelFrame.DepthValue[rsp]
|
||||
IF ColumnCount GE 32
|
||||
IF ColumnCount GE 48
|
||||
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
|
||||
ELSE
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
|
||||
ENDIF
|
||||
add_immed r13,ColumnCount*4 ; advance ColumnSumVector by N columns
|
||||
ELSE
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
|
||||
ENDIF
|
||||
; Seed the accumulators: each is a depth + column-sum vector plus the row sum
; of its row, broadcast directly from memory at r12. Rows 1-6 map to
; zmm14-zmm19 (from zmm0), zmm20-zmm25 (from zmm1, 32+ columns) and
; zmm26-zmm31 (from zmm2, 48 columns).
EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd zmm14,zmm0,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 32, <vpaddd zmm20,zmm1,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 1, ColumnCount, 48, <vpaddd zmm26,zmm2,DWORD BCST [r12]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd zmm15,zmm0,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 32, <vpaddd zmm21,zmm1,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 2, ColumnCount, 48, <vpaddd zmm27,zmm2,DWORD BCST [r12+4]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 16, <vpaddd zmm16,zmm0,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 32, <vpaddd zmm22,zmm1,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 3, ColumnCount, 48, <vpaddd zmm28,zmm2,DWORD BCST [r12+8]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 16, <vpaddd zmm17,zmm0,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 32, <vpaddd zmm23,zmm1,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 4, ColumnCount, 48, <vpaddd zmm29,zmm2,DWORD BCST [r12+12]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 16, <vpaddd zmm18,zmm0,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 32, <vpaddd zmm24,zmm1,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 5, ColumnCount, 48, <vpaddd zmm30,zmm2,DWORD BCST [r12+16]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 16, <vpaddd zmm19,zmm0,DWORD BCST [r12+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 32, <vpaddd zmm25,zmm1,DWORD BCST [r12+20]>
|
||||
EmitIfCount2GE RowCount, 6, ColumnCount, 48, <vpaddd zmm31,zmm2,DWORD BCST [r12+20]>
|
||||
|
||||
;
|
||||
; Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
;
|
||||
|
||||
; rbx serves two purposes: during ComputeBlockLoop (defined elsewhere) it is
; rcx + 3*r9, and afterwards it is repointed to r8 + 3*rax for the output code.
IF RowCount GT 3
|
||||
lea rbx,[r9*2+r9]
|
||||
add rbx,rcx ; compute matrix A plus 3 rows
|
||||
ENDIF
|
||||
ComputeBlockLoop ColumnCount, RowCount
|
||||
IF RowCount GT 3
|
||||
lea rbx,[r8+rax*2] ; compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
ENDIF
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates code to compute matrix multiplication for a fixed set
|
||||
; of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; RowCount - Supplies the number of rows to process.
|
||||
;
|
||||
; Implicit Arguments:
|
||||
;
|
||||
; rax - Supplies the length in bytes of a row from matrix C.
|
||||
;
|
||||
; rcx - Supplies the address of matrix A.
|
||||
;
|
||||
; rdx - Supplies the address of matrix B.
|
||||
;
|
||||
; r8 - Supplies the address of matrix C.
|
||||
;
|
||||
; rdi - Supplies the address of matrix A.
|
||||
;
|
||||
; rbp - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; r9 - Supplies the length in bytes of a row from matrix A.
|
||||
;
|
||||
; r10b - Supplies the zero mode flag.
|
||||
;
|
||||
; r12 - Supplies the address of the row sum vector.
|
||||
;
|
||||
; r13 - Supplies the address of the column sum vector.
|
||||
;
|
||||
; r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
;
|
||||
|
||||
ProcessCountM MACRO RowCount
; Emits the column loop for a fixed number of rows (RowCount). Each pass
; invokes ProduceOutputBlock for 48, 32 or 16 columns and then writes the
; results to matrix C at r8 in 16-column slices, until the column count in
; rbp reaches zero.
|
||||
|
||||
LOCAL ProcessNextColumnLoop32xN
|
||||
LOCAL Output32xNBlock
|
||||
LOCAL SkipAccumulateOutput32xNBlock
|
||||
LOCAL Output16xNBlock
|
||||
LOCAL Output16xNBlockWithMask
|
||||
LOCAL SkipAccumulateOutput16xNBlockWithMask
|
||||
LOCAL ProcessRemainingCountN
|
||||
LOCAL ProcessNextColumnLoop48xN
|
||||
LOCAL SkipAccumulateOutput48xNBlock
|
||||
|
||||
; Pick the widest pass that fits the remaining columns in rbp: more than 32
; columns takes the 48-column path, 17 to 32 falls into the 32-column path
; below, and 16 or fewer goes to the remaining-column path.
cmp rbp,32
|
||||
ja ProcessNextColumnLoop48xN
|
||||
cmp rbp,16
|
||||
jbe ProcessRemainingCountN
|
||||
|
||||
ProcessNextColumnLoop32xN:
|
||||
ProduceOutputBlock 32, RowCount
|
||||
add rdx,r14 ; advance matrix B by packed block stride
|
||||
|
||||
; Write the 16-column slice held in zmm20-zmm25 (one register per row). The
; existing contents of matrix C are added first unless ZeroMode is set. This
; label is also the continuation of the 48-column path.
Output32xNBlock:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput32xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm20,zmm20,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm21,zmm21,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm22,zmm22,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput32xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm20>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm21>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm22>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm23>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
; Account for the 16 columns that were just written.
sub rbp,16
|
||||
|
||||
; Final 16-column slice of a pass (zmm14-zmm19). When fewer than 16 columns
; remain the subtraction borrows, and k1 is rebuilt as (1 << remaining) - 1 so
; that the masked loads and stores below only touch the valid columns.
Output16xNBlock:
|
||||
sub rbp,16
|
||||
jae Output16xNBlockWithMask
|
||||
lea ecx,[ebp+16] ; correct for over-subtract above
|
||||
mov esi,1
|
||||
shl esi,cl
|
||||
dec esi
|
||||
kmovw k1,esi ; update mask for remaining columns
|
||||
xor ebp,ebp ; no more columns remaining
|
||||
|
||||
Output16xNBlockWithMask:
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput16xNBlockWithMask
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm14{k1},zmm14,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm15{k1},zmm15,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm16{k1},zmm16,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput16xNBlockWithMask:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8]{k1},zmm14>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm15>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm16>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19>
|
||||
; NOTE(review): rbx is not advanced after this slice; presumably
; ProduceOutputBlock recomputes it for the next pass - confirm against that macro.
add r8,16*4 ; advance matrix C by 16 columns
|
||||
mov rcx,rdi ; reload matrix A
|
||||
; Select the next pass using the same thresholds as at macro entry, or leave
; through ExitKernel once no columns remain.
cmp rbp,32
|
||||
ja ProcessNextColumnLoop48xN
|
||||
cmp rbp,16
|
||||
ja ProcessNextColumnLoop32xN
|
||||
test rbp,rbp
|
||||
jz ExitKernel
|
||||
|
||||
ProcessRemainingCountN:
|
||||
ProduceOutputBlock 16, RowCount
|
||||
jmp Output16xNBlock
|
||||
|
||||
; 48-column pass: zmm26-zmm31 are written here as the first 16 columns, then
; control joins the 32-column output path for the remaining two slices.
ProcessNextColumnLoop48xN:
|
||||
ProduceOutputBlock 48, RowCount
|
||||
lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride
|
||||
test r10b,r10b ; ZeroMode?
|
||||
jnz SkipAccumulateOutput48xNBlock
|
||||
EmitIfCountGE RowCount, 1, <vpaddd zmm26,zmm26,ZMMWORD PTR [r8]>
|
||||
EmitIfCountGE RowCount, 2, <vpaddd zmm27,zmm27,ZMMWORD PTR [r8+rax]>
|
||||
EmitIfCountGE RowCount, 3, <vpaddd zmm28,zmm28,ZMMWORD PTR [r8+rax*2]>
|
||||
EmitIfCountGE RowCount, 4, <vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]>
|
||||
EmitIfCountGE RowCount, 5, <vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]>
|
||||
EmitIfCountGE RowCount, 6, <vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]>
|
||||
|
||||
SkipAccumulateOutput48xNBlock:
|
||||
EmitIfCountGE RowCount, 1, <vmovdqu32 ZMMWORD PTR [r8],zmm26>
|
||||
EmitIfCountGE RowCount, 2, <vmovdqu32 ZMMWORD PTR [r8+rax],zmm27>
|
||||
EmitIfCountGE RowCount, 3, <vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm28>
|
||||
EmitIfCountGE RowCount, 4, <vmovdqu32 ZMMWORD PTR [rbx],zmm29>
|
||||
EmitIfCountGE RowCount, 5, <vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30>
|
||||
EmitIfCountGE RowCount, 6, <vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31>
|
||||
add r8,16*4 ; advance matrix C by 16 columns
|
||||
IF RowCount GT 3
|
||||
add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns
|
||||
ENDIF
|
||||
sub rbp,16
|
||||
jmp Output32xNBlock
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
; Macro Description:
|
||||
;
|
||||
; This macro generates the common AVX512 code for the inner kernel to compute
|
||||
; matrix multiplication.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Type - Supplies the kernel type string for function tags.
|
||||
;
|
||||
; Isa - Supplies the instruction set architecture string for function tags.
|
||||
;
|
||||
|
||||
GemmU8X8KernelAvx512Function MACRO Type, Isa
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine is an inner kernel to compute matrix multiplication for a
|
||||
; set of rows.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; A (rcx) - Supplies the address of matrix A. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackAAvx2.
|
||||
;
|
||||
; B (rdx) - Supplies the address of matrix B. The matrix data has been packed
|
||||
; using MlasGemmU8X8CopyPackBAvx2.
|
||||
;
|
||||
; C (r8) - Supplies the address of matrix C.
|
||||
;
|
||||
; QuadCountK (r9) - Supplies the number of quad columns from matrix A and the
|
||||
; number of quad rows from matrix B to iterate over.
|
||||
;
|
||||
; CountM - Supplies the maximum number of rows that can be processed for
|
||||
; matrix A and matrix C. The actual number of rows handled for this
|
||||
; invocation depends on the kernel implementation.
|
||||
;
|
||||
; CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
; over.
|
||||
;
|
||||
; ldc - Supplies the first dimension of matrix C.
|
||||
;
|
||||
; RowSumVector - Supplies the sum of each row from matrix A multiplied by the
|
||||
; zero point offset of matrix B. These values are accumulated into every
|
||||
; row of matrix C.
|
||||
;
|
||||
; ColumnSumVector - Supplies the sum of each column from matrix B multiplied
|
||||
; by the zero point offset of matrix A. These values are accumulated into
|
||||
; every column of matrix C.
|
||||
;
|
||||
; DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
; of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
; accumulated into every element of matrix C.
|
||||
;
|
||||
; ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
; else false if the output matrix is accumulated into.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; Returns the number of rows handled.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasGemm&Type&Kernel&Isa&, _TEXT
|
||||
|
||||
rex_push_reg rbp
|
||||
push_reg rbx
|
||||
push_reg rsi
|
||||
push_reg rdi
|
||||
push_reg r12
|
||||
push_reg r13
|
||||
push_reg r14
|
||||
alloc_stack (GemmU8X8KernelFrame.SavedR14)
|
||||
save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14
|
||||
save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15
|
||||
|
||||
END_PROLOGUE
|
||||
|
||||
; Keep the matrix A base in rdi; ProcessCountM reloads rcx from it after each
; column pass.
mov rdi,rcx
|
||||
mov rbp,GemmU8X8KernelFrame.CountN[rsp]
|
||||
mov rax,GemmU8X8KernelFrame.ldc[rsp]
|
||||
shl rax,2 ; convert ldc to bytes
|
||||
shl r9,2 ; convert to row length
|
||||
movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp]
|
||||
mov r11,GemmU8X8KernelFrame.CountM[rsp]
|
||||
mov r12,GemmU8X8KernelFrame.RowSumVector[rsp]
|
||||
mov r13,GemmU8X8KernelFrame.ColumnSumVector[rsp]
|
||||
mov esi,-1
|
||||
kmovw k1,esi ; update mask to write all columns
|
||||
IFIDNI <Type>, <U8S8>
|
||||
IFIDNI <Isa>, <Avx512BW>
|
||||
; esi still holds -1 from the mask setup above, so negating it yields 1.
neg esi
|
||||
vpbroadcastw zmm5,esi ; generate 512-bit word vector [0x0001]
|
||||
ENDIF
|
||||
; U8S8 stride: row length in bytes (r9) times 16.
mov r14,r9
|
||||
shl r14,4 ; compute matrix B packed stride
|
||||
ELSE
|
||||
; Any other Type: row length in bytes (r9) times 8.
lea r14,[r9*8] ; compute matrix B packed stride
|
||||
ENDIF
|
||||
|
||||
;
|
||||
; Process CountM rows of the matrices.
|
||||
;
|
||||
; The compares below branch to the ProcessCountM expansion that matches the row
; count in r11. Counts above 5 take the 6-row path, and any count not matched
; falls through to the 2-row path.
;
|
||||
|
||||
cmp r11,5
|
||||
ja ProcessCountM6
|
||||
je ProcessCountM5
|
||||
cmp r11,3
|
||||
ja ProcessCountM4
|
||||
je ProcessCountM3
|
||||
cmp r11,1
|
||||
je ProcessCountM1
|
||||
|
||||
ProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
ProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
ProcessCountM6:
|
||||
mov r11d,6 ; return 6 rows handled
|
||||
ProcessCountM 6
|
||||
|
||||
;
|
||||
; Restore non-volatile registers and return.
|
||||
;
|
||||
; r11d is copied to eax as the number of rows handled.
;
|
||||
|
||||
ExitKernel:
|
||||
mov eax,r11d
|
||||
vzeroupper
|
||||
movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp]
|
||||
movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp]
|
||||
add rsp,(GemmU8X8KernelFrame.SavedR14)
|
||||
|
||||
BEGIN_EPILOGUE
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rdi
|
||||
pop rsi
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
; The odd row counts are expanded after the epilogue. As far as the visible
; ProcessCountM body shows, every expansion ends in an unconditional jump and
; leaves only through ExitKernel.
ProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
ProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
ProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
NESTED_END MlasGemm&Type&Kernel&Isa&, _TEXT
|
||||
|
||||
ENDM
|
||||
|
|
@ -194,6 +194,52 @@ void
|
|||
|
||||
typedef MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE;
|
||||
|
||||
//
// Prototype of the routines that pack matrix A (uint8_t) for the U8S8 GEMM
// kernels. D receives the packed bytes, RowSumVector receives one sum per
// row, and offb is the zero point offset of matrix B.
//
typedef
|
||||
void
|
||||
(MLASCALL MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE)(
|
||||
uint8_t* D,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
size_t CountM,
|
||||
size_t CountK,
|
||||
int32_t* RowSumVector,
|
||||
int16_t offb
|
||||
);
|
||||
|
||||
typedef MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE* PMLAS_GEMM_U8S8_COPY_PACKA_ROUTINE;
|
||||
|
||||
//
// Prototype of the routines that pack matrix B (int8_t) for the U8S8 GEMM
// kernels. D receives the packed bytes, ColumnSumVector receives one sum per
// column, and offa is the zero point offset of matrix A.
//
typedef
|
||||
void
|
||||
(MLASCALL MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE)(
|
||||
int8_t* D,
|
||||
const int8_t* B,
|
||||
size_t ldb,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
int32_t* ColumnSumVector,
|
||||
int16_t offa
|
||||
);
|
||||
|
||||
typedef MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE* PMLAS_GEMM_U8S8_COPY_PACKB_ROUTINE;
|
||||
|
||||
//
// Prototype of the U8S8 GEMM inner kernels. A and B are the packed source
// matrices, C is the int32_t output matrix, and the return value is the
// number of rows the kernel handled.
//
typedef
|
||||
size_t
|
||||
(MLASCALL MLAS_GEMM_U8S8_KERNEL)(
|
||||
const uint8_t* A,
|
||||
const int8_t* B,
|
||||
int32_t* C,
|
||||
size_t QuadCountK,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc,
|
||||
const int32_t* RowSumVector,
|
||||
const int32_t* ColumnSumVector,
|
||||
int32_t DepthValue,
|
||||
bool ZeroMode
|
||||
);
|
||||
|
||||
typedef MLAS_GEMM_U8S8_KERNEL* PMLAS_GEMM_U8S8_KERNEL;
|
||||
|
||||
typedef
|
||||
void
|
||||
(MLASCALL MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE)(
|
||||
|
|
@ -228,7 +274,7 @@ size_t
|
|||
const int16_t* A,
|
||||
const uint8_t* B,
|
||||
int32_t* C,
|
||||
size_t PairedCountK,
|
||||
size_t PairCountK,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc,
|
||||
|
|
@ -364,10 +410,18 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE MlasGemmU8S8CopyPackASse;
|
||||
MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE MlasGemmU8S8CopyPackBSse;
|
||||
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelSse;
|
||||
MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackASse;
|
||||
MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBSse;
|
||||
MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelSse;
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
MLAS_GEMM_U8S8_COPY_PACKA_ROUTINE MlasGemmU8S8CopyPackAAvx2;
|
||||
MLAS_GEMM_U8S8_COPY_PACKB_ROUTINE MlasGemmU8S8CopyPackBAvx2;
|
||||
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx2;
|
||||
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512BW;
|
||||
MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512Vnni;
|
||||
MLAS_GEMM_U8U8_COPY_PACKA_ROUTINE MlasGemmU8U8CopyPackAAvx2;
|
||||
MLAS_GEMM_U8U8_COPY_PACKB_ROUTINE MlasGemmU8U8CopyPackBAvx2;
|
||||
MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx2;
|
||||
|
|
@ -490,6 +544,9 @@ struct MLAS_PLATFORM {
|
|||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
PMLAS_GEMM_FLOAT_KERNEL GemmFloatKernel;
|
||||
PMLAS_GEMM_U8S8_COPY_PACKA_ROUTINE GemmU8S8CopyPackARoutine;
|
||||
PMLAS_GEMM_U8S8_COPY_PACKB_ROUTINE GemmU8S8CopyPackBRoutine;
|
||||
PMLAS_GEMM_U8S8_KERNEL GemmU8S8Kernel;
|
||||
PMLAS_GEMM_U8U8_COPY_PACKA_ROUTINE GemmU8U8CopyPackARoutine;
|
||||
PMLAS_GEMM_U8U8_COPY_PACKB_ROUTINE GemmU8U8CopyPackBRoutine;
|
||||
PMLAS_GEMM_U8U8_KERNEL GemmU8U8Kernel;
|
||||
|
|
|
|||
|
|
@ -85,6 +85,9 @@ Return Value:
|
|||
//
|
||||
|
||||
this->GemmFloatKernel = MlasGemmFloatKernelSse;
|
||||
this->GemmU8S8CopyPackARoutine = MlasGemmU8S8CopyPackASse;
|
||||
this->GemmU8S8CopyPackBRoutine = MlasGemmU8S8CopyPackBSse;
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelSse;
|
||||
this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackASse;
|
||||
this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBSse;
|
||||
this->GemmU8U8Kernel = MlasGemmU8U8KernelSse;
|
||||
|
|
@ -157,6 +160,9 @@ Return Value:
|
|||
|
||||
if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) {
|
||||
|
||||
this->GemmU8S8CopyPackARoutine = MlasGemmU8S8CopyPackAAvx2;
|
||||
this->GemmU8S8CopyPackBRoutine = MlasGemmU8S8CopyPackBAvx2;
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2;
|
||||
this->GemmU8U8CopyPackARoutine = MlasGemmU8U8CopyPackAAvx2;
|
||||
this->GemmU8U8CopyPackBRoutine = MlasGemmU8U8CopyPackBAvx2;
|
||||
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2;
|
||||
|
|
@ -180,6 +186,7 @@ Return Value:
|
|||
|
||||
if ((Cpuid7[1] & 0x40000000) != 0) {
|
||||
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512BW;
|
||||
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512BW;
|
||||
|
||||
//
|
||||
|
|
@ -187,6 +194,8 @@ Return Value:
|
|||
//
|
||||
|
||||
if ((Cpuid7[2] & 0x800) != 0) {
|
||||
|
||||
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni;
|
||||
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Vnni;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,12 +21,577 @@ Abstract:
|
|||
// Define the default strides to step through slices of the input matrices.
|
||||
//
|
||||
|
||||
#define MLAS_GEMM_U8U8_STRIDEM 12
|
||||
#define MLAS_GEMM_U8U8_STRIDEN 128
|
||||
#define MLAS_GEMM_U8S8_STRIDEM 24
|
||||
#define MLAS_GEMM_U8S8_STRIDEN 256
|
||||
#define MLAS_GEMM_U8S8_STRIDEK 128
|
||||
|
||||
#define MLAS_GEMM_U8U8_STRIDEM 24
|
||||
#define MLAS_GEMM_U8U8_STRIDEN 256
|
||||
#define MLAS_GEMM_U8U8_STRIDEK 128
|
||||
|
||||
#ifdef MLAS_TARGET_AMD64_IX86
|
||||
|
||||
//
|
||||
// U8S8 implementation using SSE2 intrinsics.
|
||||
//
|
||||
|
||||
void
MLASCALL
MlasGemmU8S8CopyPackASse(
    uint8_t* D,
    const uint8_t* A,
    size_t lda,
    size_t CountM,
    size_t CountK,
    int32_t* RowSumVector,
    int16_t offb
    )
/*++

Routine Description:

    This routine copies elements from the source matrix to the destination
    packed buffer and produces the per-row sums.

    Each packed row holds CountK bytes rounded up to a multiple of 4; the
    extra bytes are zero. Exactly that many bytes are written per row.

Arguments:

    D - Supplies the address of the destination packed buffer.

    A - Supplies the address of the source matrix.

    lda - Supplies the number of elements per row of the source matrix.

    CountM - Supplies the number of rows of the source matrix to copy.

    CountK - Supplies the number of columns of the source matrix to copy.

    RowSumVector - Supplies the address of the buffer to receive the sums of
        the elements from each of the rows. Each sum has also been multiplied
        by the zero point offset.

    offb - Supplies the zero point offset for the other source matrix of the
        matrix multiplication.

Return Value:

    None.

--*/
{
    const __m128i ZeroVector = _mm_setzero_si128();
    const __m128i OffsetBroadcast = _mm_set1_epi16(offb);

    //
    // The tail length is the same for every row, so the bytes past the tail
    // are never written and stay zero for the lifetime of this buffer.
    //

    uint8_t PaddedMatrixAData[8] = { 0 };

    //
    // Process a single row of matrix A in a loop.
    //

    while (CountM > 0) {

        const uint8_t* a = A;
        size_t k = CountK;
        __m128i RowSum = ZeroVector;

        //
        // Copy the source bytes to the packed buffer.
        //
        // The packed buffer has the same data ordering as the source bytes,
        // but CountK is aligned up to a multiple of 4 to maintain 32-bit
        // alignment. All extra bytes are zero-padded.
        //
        // These values are also zero-extended and accumulated into an
        // intermediate per-row accumulator. CountK cannot be greater than 128
        // to avoid overflowing these signed 16-bit accumulators.
        //

        while (k >= 8) {

            __m128i Bytes = _mm_loadl_epi64((__m128i*)&a[0]);
            _mm_storel_epi64((__m128i*)&D[0], Bytes);

            RowSum = _mm_add_epi16(RowSum, _mm_unpacklo_epi8(Bytes, ZeroVector));

            D += 8;
            a += 8;
            k -= 8;
        }

        if (k > 0) {

            //
            // Copy the remaining bytes to the zero padded stack buffer.
            //

            uint8_t* padded = PaddedMatrixAData;
            uint8_t* padded_end = padded + k;

            do {
                padded[0] = a[0];
                padded++;
                a++;
            } while (padded < padded_end);

            __m128i Bytes = _mm_loadl_epi64((__m128i*)PaddedMatrixAData);

            RowSum = _mm_add_epi16(RowSum, _mm_unpacklo_epi8(Bytes, ZeroVector));

            //
            // Copy quads of 8-bit values from the vector to the packed
            // buffer and rotate the vector for the next iteration.
            //
            // Only whole quads are stored here. A full 8-byte store of the
            // vector would run 4 bytes past the packed row whenever four or
            // fewer bytes remain, so it must not be used for the tail.
            //

            for (size_t quads = (k + 3) / 4; quads > 0; quads--) {
                *((int32_t*)D) = _mm_cvtsi128_si32(Bytes);
                D += 4;
                Bytes = _mm_shuffle_epi32(Bytes, _MM_SHUFFLE(0, 3, 2, 1));
            }
        }

        //
        // Reduce the sum for the single row of output and multiply by the
        // zero point offset of the other source matrix.
        //

        RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast);
        RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(3, 2, 3, 2)));
        RowSum = _mm_add_epi32(RowSum, _mm_shuffle_epi32(RowSum, _MM_SHUFFLE(0, 1, 0, 1)));

        *RowSumVector++ = _mm_cvtsi128_si32(RowSum);

        A += lda;
        CountM -= 1;
    }
}
|
||||
|
||||
void
MLASCALL
MlasGemmU8S8CopyPackBSse(
    int8_t* D,
    const int8_t* B,
    size_t ldb,
    size_t CountN,
    size_t CountK,
    int32_t* ColumnSumVector,
    int16_t offa
    )
/*++

Routine Description:

    This routine packs matrix B in blocks of 8 columns. Inside a block, two
    source rows are interleaved byte by byte into each 16 byte group, and the
    number of rows written is rounded up to a multiple of 4 with zero rows.

    The signed column sums, multiplied by the zero point offset, are produced
    along the way.

Arguments:

    D - Supplies the address of the destination packed buffer.

    B - Supplies the address of the source matrix.

    ldb - Supplies the number of elements per row of the source matrix.

    CountN - Supplies the number of columns of the source matrix to copy.

    CountK - Supplies the number of rows of the source matrix to copy. The
        16-bit intermediate sums limit this to 128 rows.

    ColumnSumVector - Supplies the address of the buffer to receive the sums of
        the elements from each of the columns. Each sum has also been multiplied
        by the zero point offset. Eight sums are stored per column block, also
        for a final partial block.

    offa - Supplies the zero point offset for the other source matrix of the
        matrix multiplication.

Return Value:

    None.

--*/
{
    const __m128i Zeros = _mm_setzero_si128();
    const __m128i OffsetA = _mm_set1_epi16(offa);

    //
    // Staging area for a partial column block: bytes 0..7 hold the even row
    // and bytes 8..15 the odd row. Bytes beyond CountN are never written and
    // stay zero.
    //

    int8_t Staging[16] = { 0 };

    //
    // A group of zero rows is appended when CountK modulo 4 is 1 or 2.
    //

    const bool AppendZeroGroup = (((CountK - 1) & 2) == 0);

    __m128i SumLow = Zeros;
    __m128i SumHigh = Zeros;

    //
    // Interleaves two rows, writes the 16 byte group and adds the
    // sign-extended bytes to the 16-bit column accumulators.
    //

    auto EmitRowPair = [&](__m128i Row0, __m128i Row1) {
        const __m128i Interleaved = _mm_unpacklo_epi8(Row0, Row1);
        _mm_storeu_si128((__m128i*)D, Interleaved);
        SumLow = _mm_add_epi16(SumLow, _mm_srai_epi16(_mm_unpacklo_epi8(Zeros, Interleaved), 8));
        SumHigh = _mm_add_epi16(SumHigh, _mm_srai_epi16(_mm_unpackhi_epi8(Zeros, Interleaved), 8));
        D += 16;
    };

    //
    // Completes a column block: optional zero group, then the eight column
    // sums scaled by the zero point offset.
    //

    auto FinishColumnBlock = [&]() {
        if (AppendZeroGroup) {
            _mm_storeu_si128((__m128i*)D, Zeros);
            D += 16;
        }
        _mm_storeu_si128((__m128i*)&ColumnSumVector[0], _mm_madd_epi16(SumLow, OffsetA));
        _mm_storeu_si128((__m128i*)&ColumnSumVector[4], _mm_madd_epi16(SumHigh, OffsetA));
    };

    //
    // Full blocks of 8 columns are read straight from the source matrix.
    //

    for (; CountN >= 8; CountN -= 8, B += 8, ColumnSumVector += 8) {

        const int8_t* b = B;
        size_t k = CountK;

        SumLow = Zeros;
        SumHigh = Zeros;

        for (; k >= 2; k -= 2, b += ldb * 2) {
            EmitRowPair(_mm_loadl_epi64((const __m128i*)b),
                        _mm_loadl_epi64((const __m128i*)(b + ldb)));
        }

        if (k != 0) {
            // Odd final row: pair it with zeros.
            EmitRowPair(_mm_loadl_epi64((const __m128i*)b), Zeros);
        }

        FinishColumnBlock();
    }

    //
    // The final partial block goes through the staging buffer so that no
    // bytes past the last column are read.
    //

    if (CountN != 0) {

        const int8_t* b = B;
        size_t k = CountK;

        SumLow = Zeros;
        SumHigh = Zeros;

        for (; k >= 2; k -= 2, b += ldb * 2) {

            for (size_t n = 0; n < CountN; n++) {
                Staging[n] = b[n];
                Staging[8 + n] = b[ldb + n];
            }

            EmitRowPair(_mm_loadl_epi64((const __m128i*)&Staging[0]),
                        _mm_loadl_epi64((const __m128i*)&Staging[8]));
        }

        if (k != 0) {

            for (size_t n = 0; n < CountN; n++) {
                Staging[n] = b[n];
            }

            EmitRowPair(_mm_loadl_epi64((const __m128i*)&Staging[0]), Zeros);
        }

        FinishColumnBlock();
    }
}
|
||||
|
||||
size_t
MLASCALL
MlasGemmU8S8KernelSse(
    const uint8_t* A,
    const int8_t* B,
    int32_t* C,
    size_t PairCountK,
    size_t CountM,
    size_t CountN,
    size_t ldc,
    const int32_t* RowSumVector,
    const int32_t* ColumnSumVector,
    int32_t DepthValue,
    bool ZeroMode
    )
/*++

Routine Description:

    This routine is an inner kernel to compute matrix multiplication for a
    single row of matrix A against blocks of 8 columns of matrix B.

Arguments:

    A - Supplies the address of matrix A. The matrix data has been packed
        using MlasGemmU8S8CopyPackASse.

    B - Supplies the address of matrix B. The matrix data has been packed
        using MlasGemmU8S8CopyPackBSse.

    C - Supplies the address of matrix C.

    PairCountK - Supplies the number of inner loop steps. Each step consumes
        4 bytes of the packed row of matrix A and 32 bytes (4 rows of 8
        columns) of packed matrix B.

    CountM - Supplies the maximum number of rows that can be processed for
        matrix A and matrix C. Not referenced: this kernel always handles one
        row.

    CountN - Supplies the number of columns from matrix B and matrix C to iterate
        over.

    ldc - Supplies the first dimension of matrix C. Not referenced.

    RowSumVector - Supplies the sum of each row from matrix A multiplied by the
        zero point offset of matrix B. Only the first element is used.

    ColumnSumVector - Supplies the sum of each column from matrix B multiplied
        by the zero point offset of matrix A. Eight values are read per column
        block, also for a final partial block.

    DepthValue - Supplies the value CountK multiplied by the zero point offset
        of matrix A multiplied by the zero point offset of matrix B. This value
        is accumulated into every element of matrix C.

    ZeroMode - Supplies true if the output matrix must be zero initialized,
        else false if the output matrix is accumulated into.

Return Value:

    Returns the number of rows handled, which is always 1.

--*/
{
    const __m128i Zeros = _mm_setzero_si128();
    const bool AccumulateIntoC = !ZeroMode;

    MLAS_UNREFERENCED_PARAMETER(CountM);
    MLAS_UNREFERENCED_PARAMETER(ldc);

    while (CountN > 0) {

        //
        // Every output starts from the depth constant plus the row sum; the
        // per-column sums are then added for columns 0..3 and 4..7.
        //

        const __m128i Seed =
            _mm_add_epi32(_mm_set1_epi32(DepthValue), _mm_set1_epi32(RowSumVector[0]));

        __m128i SumLow = _mm_add_epi32(Seed, _mm_loadu_si128((const __m128i*)&ColumnSumVector[0]));
        __m128i SumHigh = _mm_add_epi32(Seed, _mm_loadu_si128((const __m128i*)&ColumnSumVector[4]));

        ColumnSumVector += 8;

        const uint8_t* a = A;

        for (size_t k = PairCountK; k > 0; k--) {

            //
            // Zero-extend four bytes of matrix A to 16 bits and broadcast
            // them as two pairs: elements 0,1 and elements 2,3.
            //

            const __m128i AWords =
                _mm_unpacklo_epi8(_mm_cvtsi32_si128(*((const int32_t*)a)), Zeros);
            const __m128i APair01 = _mm_shuffle_epi32(AWords, _MM_SHUFFLE(0, 0, 0, 0));
            const __m128i APair23 = _mm_shuffle_epi32(AWords, _MM_SHUFFLE(1, 1, 1, 1));

            //
            // Each 16 byte group of matrix B interleaves two rows for eight
            // columns. The bytes are sign-extended to 16 bits, multiplied
            // with the matching pair from matrix A and summed pairwise into
            // the 32-bit accumulators.
            //

            const __m128i BRows01 = _mm_loadu_si128((const __m128i*)&B[0]);
            const __m128i BRows23 = _mm_loadu_si128((const __m128i*)&B[16]);

            SumLow = _mm_add_epi32(SumLow,
                _mm_madd_epi16(_mm_srai_epi16(_mm_unpacklo_epi8(Zeros, BRows01), 8), APair01));
            SumHigh = _mm_add_epi32(SumHigh,
                _mm_madd_epi16(_mm_srai_epi16(_mm_unpackhi_epi8(Zeros, BRows01), 8), APair01));

            SumLow = _mm_add_epi32(SumLow,
                _mm_madd_epi16(_mm_srai_epi16(_mm_unpacklo_epi8(Zeros, BRows23), 8), APair23));
            SumHigh = _mm_add_epi32(SumHigh,
                _mm_madd_epi16(_mm_srai_epi16(_mm_unpackhi_epi8(Zeros, BRows23), 8), APair23));

            a += 4;
            B += 32;
        }

        //
        // Full block: write all eight columns and move to the next block.
        //

        if (CountN >= 8) {

            if (AccumulateIntoC) {
                SumLow = _mm_add_epi32(SumLow, _mm_loadu_si128((const __m128i*)&C[0]));
                SumHigh = _mm_add_epi32(SumHigh, _mm_loadu_si128((const __m128i*)&C[4]));
            }

            _mm_storeu_si128((__m128i*)&C[0], SumLow);
            _mm_storeu_si128((__m128i*)&C[4], SumHigh);

            C += 8;
            CountN -= 8;
            continue;
        }

        //
        // Final partial block: emit 4, 2 and 1 columns as selected by the
        // bits of CountN, shifting the pending values down after each step.
        //

        __m128i Pending = SumLow;

        if ((CountN & 4) != 0) {

            if (AccumulateIntoC) {
                Pending = _mm_add_epi32(Pending, _mm_loadu_si128((const __m128i*)&C[0]));
            }

            _mm_storeu_si128((__m128i*)&C[0], Pending);
            C += 4;

            Pending = SumHigh;
        }

        if ((CountN & 2) != 0) {

            if (AccumulateIntoC) {
                Pending = _mm_add_epi32(Pending, _mm_loadl_epi64((const __m128i*)&C[0]));
            }

            _mm_storel_epi64((__m128i*)&C[0], Pending);
            C += 2;

            Pending = _mm_shuffle_epi32(Pending, _MM_SHUFFLE(1, 0, 3, 2));
        }

        if ((CountN & 1) != 0) {

            int32_t Value = _mm_cvtsi128_si32(Pending);

            if (AccumulateIntoC) {
                Value += C[0];
            }

            C[0] = Value;
        }

        CountN = 0;
    }

    return 1;
}
|
||||
|
||||
//
|
||||
// U8U8 implementation using SSE2 intrinsics.
|
||||
//
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemmU8U8CopyPackASse(
|
||||
|
|
@ -86,13 +651,15 @@ Return Value:
|
|||
|
||||
//
|
||||
// Zero extend the source bytes to 16-bits and write to the packed
|
||||
// buffer. The packed buffer has the same data ordering as the source
|
||||
// bytes, but the stride is CountK aligned up to a multiple of 8
|
||||
// values.
|
||||
// buffer.
|
||||
//
|
||||
// The packed buffer has the same data ordering as the source bytes,
|
||||
// but CountK is aligned up to a multiple of 2 to maintain 32-bit
|
||||
// alignment. All extra bytes are zero-padded.
|
||||
//
|
||||
// These 16-bit values are also accumulated into an intermediate per-row
|
||||
// accumulator. CountK cannot be greater than 256 to avoid overflowing
|
||||
// these 16-bit accumulators.
|
||||
// accumulator. CountK cannot be greater than 128 to avoid overflowing
|
||||
// these signed 16-bit accumulators.
|
||||
//
|
||||
|
||||
while (k >= 8) {
|
||||
|
|
@ -130,8 +697,8 @@ Return Value:
|
|||
RowSum = _mm_add_epi16(RowSum, Words);
|
||||
|
||||
//
|
||||
// Copy the 16-bit pairs from the vector to the destination packed
|
||||
// buffer. Rotate the vector at each iteration.
|
||||
// Copy pairs of 16-bit values from the vector to the packed
|
||||
// buffer and rotate the vector for the next iteration.
|
||||
//
|
||||
|
||||
for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) {
|
||||
|
|
@ -142,7 +709,8 @@ Return Value:
|
|||
}
|
||||
|
||||
//
|
||||
// Reduce the sum for the single row of output.
|
||||
// Reduce the sum for the single row of output and multiply by the
|
||||
// zero point offset of the other source matrix.
|
||||
//
|
||||
|
||||
RowSum = _mm_madd_epi16(RowSum, OffsetBroadcast);
|
||||
|
|
@ -176,13 +744,13 @@ Routine Description:
|
|||
|
||||
Arguments:
|
||||
|
||||
D (rcx) - Supplies the address of the destination packed buffer.
|
||||
D - Supplies the address of the destination packed buffer.
|
||||
|
||||
B (rdx) - Supplies the address of the source matrix.
|
||||
B - Supplies the address of the source matrix.
|
||||
|
||||
ldb (r8) - Supplies the number of elements per row of the source matrix.
|
||||
ldb - Supplies the number of elements per row of the source matrix.
|
||||
|
||||
CountN (r9) - Supplies the number of columns of the source matrix to copy.
|
||||
CountN - Supplies the number of columns of the source matrix to copy.
|
||||
|
||||
CountK - Supplies the number of rows of the source matrix to copy.
|
||||
|
||||
|
|
@ -214,6 +782,14 @@ Return Value:
|
|||
__m128i ColumnSum0 = ZeroVector;
|
||||
__m128i ColumnSum1 = ZeroVector;
|
||||
|
||||
//
|
||||
// Interleave 2 rows of matrix B and write to the packed buffer.
|
||||
//
|
||||
// These values are also zero-extended and accumulated into an
|
||||
// intermediate per-column accumulator. CountK cannot be greater than
|
||||
// 128 to avoid overflowing these signed 16-bit accumulators.
|
||||
//
|
||||
|
||||
while (k >= 2) {
|
||||
|
||||
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
|
||||
|
|
@ -232,6 +808,10 @@ Return Value:
|
|||
|
||||
if (k > 0) {
|
||||
|
||||
//
|
||||
// Process the remaining row of matrix B.
|
||||
//
|
||||
|
||||
__m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
|
||||
__m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, ZeroVector);
|
||||
|
||||
|
|
@ -240,7 +820,6 @@ Return Value:
|
|||
ColumnSum0 = _mm_add_epi16(ColumnSum0, _mm_unpacklo_epi8(BytesInterleaved, ZeroVector));
|
||||
ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector));
|
||||
|
||||
b += ldb * 2;
|
||||
D += 16;
|
||||
}
|
||||
|
||||
|
|
@ -323,6 +902,11 @@ Return Value:
|
|||
ColumnSum1 = _mm_add_epi16(ColumnSum1, _mm_unpackhi_epi8(BytesInterleaved, ZeroVector));
|
||||
}
|
||||
|
||||
//
|
||||
// Reduce the sum for the packed columns and multiply by the zero point
|
||||
// offset of the other source matrix.
|
||||
//
|
||||
|
||||
ColumnSum0 = _mm_madd_epi16(ColumnSum0, OffsetBroadcast);
|
||||
ColumnSum1 = _mm_madd_epi16(ColumnSum1, OffsetBroadcast);
|
||||
|
||||
|
|
@ -337,7 +921,7 @@ MlasGemmU8U8KernelSse(
|
|||
const int16_t* A,
|
||||
const uint8_t* B,
|
||||
int32_t* C,
|
||||
size_t PairedCountK,
|
||||
size_t PairCountK,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc,
|
||||
|
|
@ -363,8 +947,8 @@ Arguments:
|
|||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
PairedCountK - Supplies the number of paired columns from matrix A and
|
||||
the number of paired rows from matrix B to iterate over.
|
||||
PairCountK - Supplies the number of paired columns from matrix A and the
|
||||
number of paired rows from matrix B to iterate over.
|
||||
|
||||
CountM - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
|
|
@ -422,18 +1006,18 @@ Return Value:
|
|||
//
|
||||
|
||||
const int16_t* a = A;
|
||||
size_t k = PairedCountK;
|
||||
size_t k = PairCountK;
|
||||
|
||||
while (k > 0) {
|
||||
|
||||
__m128i AElements0 = _mm_set1_epi32(*((int32_t*)a));
|
||||
__m128i AElements = _mm_set1_epi32(*((int32_t*)a));
|
||||
__m128i BElements0 = _mm_loadu_si128((__m128i*)&B[0]);
|
||||
|
||||
__m128i Intermediate0 = _mm_unpacklo_epi8(BElements0, ZeroVector);
|
||||
__m128i Intermediate1 = _mm_unpackhi_epi8(BElements0, ZeroVector);
|
||||
|
||||
Intermediate0 = _mm_madd_epi16(Intermediate0, AElements0);
|
||||
Intermediate1 = _mm_madd_epi16(Intermediate1, AElements0);
|
||||
Intermediate0 = _mm_madd_epi16(Intermediate0, AElements);
|
||||
Intermediate1 = _mm_madd_epi16(Intermediate1, AElements);
|
||||
|
||||
Accumulator0 = _mm_add_epi32(Accumulator0, Intermediate0);
|
||||
Accumulator1 = _mm_add_epi32(Accumulator1, Intermediate1);
|
||||
|
|
@ -502,7 +1086,7 @@ Return Value:
|
|||
C[0] = AccumulatorValue;
|
||||
}
|
||||
|
||||
break;
|
||||
CountN = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -511,7 +1095,94 @@ Return Value:
|
|||
|
||||
/*++

Routine Description:

    This routine implements the quantized integer matrix/matrix multiply
    operation (QGEMM) for an unsigned 8-bit matrix A and a signed 8-bit
    matrix B, accumulating 32-bit integer results into matrix C.

    The operation is tiled: K is stepped in slices of
    MLAS_GEMM_U8S8_STRIDEK, N in panels of MLAS_GEMM_U8S8_STRIDEN and M in
    panels of MLAS_GEMM_U8S8_STRIDEM. Each panel is repacked into a stack
    buffer by the platform CopyPackA/CopyPackB routines before the platform
    kernel is invoked on it.

Arguments:

    M - Supplies the number of rows of matrix A and matrix C.

    N - Supplies the number of columns of matrix B and matrix C.

    K - Supplies the number of columns of matrix A and the number of rows of
        matrix B.

    A - Supplies the address of matrix A.

    lda - Supplies the number of elements per row of matrix A.

    offa - Supplies the zero point offset of matrix A.

    B - Supplies the address of matrix B.

    ldb - Supplies the number of elements per row of matrix B.

    offb - Supplies the zero point offset of matrix B.

    C - Supplies the address of matrix C.

    ldc - Supplies the number of elements per row of matrix C.

    ThreadPool - Unused by this implementation.

Return Value:

    None.

--*/
void
|
||||
MLASCALL
|
||||
MlasQgemm(
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const int8_t* B,
|
||||
size_t ldb,
|
||||
int8_t offb,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
//
// Stack buffers that receive the repacked panels of A and B, plus the
// per-row and per-column sums produced by the packing routines.
//
MLAS_DECLSPEC_ALIGN(uint8_t PanelA[MLAS_GEMM_U8S8_STRIDEM * MLAS_GEMM_U8S8_STRIDEK], 64);
|
||||
MLAS_DECLSPEC_ALIGN(int8_t PanelB[MLAS_GEMM_U8S8_STRIDEN * MLAS_GEMM_U8S8_STRIDEK], 64);
|
||||
|
||||
MLAS_DECLSPEC_ALIGN(int32_t RowSumVector[MLAS_GEMM_U8S8_STRIDEM], 16);
|
||||
MLAS_DECLSPEC_ALIGN(int32_t ColumnSumVector[MLAS_GEMM_U8S8_STRIDEN], 16);
|
||||
|
||||
size_t StrideM = MLAS_GEMM_U8S8_STRIDEM;
|
||||
size_t StrideN = MLAS_GEMM_U8S8_STRIDEN;
|
||||
size_t StrideK = MLAS_GEMM_U8S8_STRIDEK;
|
||||
|
||||
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
|
||||
|
||||
size_t CountK;
|
||||
|
||||
//
// Step through each slice of the K dimension. CountK is the depth of the
// current slice, clamped to the number of K values that remain.
//
for (size_t k = 0; k < K; k += CountK) {
|
||||
|
||||
CountK = StrideK;
|
||||
|
||||
if (CountK > (K - k)) {
|
||||
CountK = K - k;
|
||||
}
|
||||
|
||||
size_t CountN;
|
||||
|
||||
//
// Step through each panel of columns of matrix B for this K slice.
//
for (size_t n = 0; n < N; n += CountN) {
|
||||
|
||||
CountN = StrideN;
|
||||
|
||||
if (CountN > (N - n)) {
|
||||
CountN = N - n;
|
||||
}
|
||||
|
||||
//
// Pack the current panel of matrix B. The packing routine also receives
// ColumnSumVector and the negated zero point offset of matrix A.
//
MlasPlatform.GemmU8S8CopyPackBRoutine(PanelB, B + n + k * ldb, ldb, CountN, CountK, ColumnSumVector, -int16_t(offa));
|
||||
|
||||
size_t CountM;
|
||||
|
||||
//
// Step through each panel of rows of matrix A, reusing the packed B panel.
//
for (size_t m = 0; m < M; m += CountM) {
|
||||
|
||||
CountM = StrideM;
|
||||
|
||||
if (CountM > (M - m)) {
|
||||
CountM = M - m;
|
||||
}
|
||||
|
||||
//
// Pack the current panel of matrix A. The packing routine also receives
// RowSumVector and the negated zero point offset of matrix B.
//
MlasPlatform.GemmU8S8CopyPackARoutine(PanelA, A + k + m * lda, lda, CountM, CountK, RowSumVector, -int16_t(offb));
|
||||
|
||||
uint8_t* pa = PanelA;
|
||||
int32_t* c = C + n + m * ldc;
|
||||
|
||||
int32_t* RowSums = RowSumVector;
|
||||
|
||||
size_t RowsRemaining = CountM;
|
||||
size_t RowsHandled;
|
||||
|
||||
//
// The kernel is given the K depth in groups of four values. Packed rows of
// A are 4 * QuadCountK bytes apart, which is how pa is advanced below.
//
size_t QuadCountK = (CountK + 3) / 4;
|
||||
|
||||
//
// The kernel reports how many rows it processed, so it is invoked until
// the whole panel of rows has been consumed. The final argument is true
// only for the first K slice (k == 0); presumably this is the zero mode
// flag of the kernel -- confirm against the kernel implementation.
//
while (RowsRemaining > 0) {
|
||||
|
||||
RowsHandled = MlasPlatform.GemmU8S8Kernel(pa, PanelB, c, QuadCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0);
|
||||
|
||||
RowsRemaining -= RowsHandled;
|
||||
c += ldc * RowsHandled;
|
||||
pa += 4 * QuadCountK * RowsHandled;
|
||||
RowSums += RowsHandled;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
|
|
@ -580,15 +1251,15 @@ MlasQgemm(
|
|||
size_t RowsRemaining = CountM;
|
||||
size_t RowsHandled;
|
||||
|
||||
size_t PairedCountK = (CountK + 1) / 2;
|
||||
size_t PairCountK = (CountK + 1) / 2;
|
||||
|
||||
while (RowsRemaining > 0) {
|
||||
|
||||
RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairedCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0);
|
||||
RowsHandled = MlasPlatform.GemmU8U8Kernel(pa, PanelB, c, PairCountK, RowsRemaining, CountN, ldc, RowSums, ColumnSumVector, int32_t(CountK) * offa * offb, k == 0);
|
||||
|
||||
RowsRemaining -= RowsHandled;
|
||||
c += ldc * RowsHandled;
|
||||
pa += 2 * PairedCountK * RowsHandled;
|
||||
pa += 2 * PairCountK * RowsHandled;
|
||||
RowSums += RowsHandled;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ Macro Description:
|
|||
|
||||
This macro builds a VNNI instruction of the form:
|
||||
|
||||
instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale]{1to16}
|
||||
instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale+ByteOffset]{1to16}
|
||||
|
||||
Arguments:
|
||||
|
||||
|
|
@ -152,13 +152,16 @@ Arguments:
|
|||
|
||||
BaseReg - Specifies the base register of the broadcast operand.
|
||||
|
||||
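ByteOffset - Specifies the DWORD aligned byte offset for the broadcast
|
||||
operand. A non-zero offset is emitted as a single disp8 byte holding
    (ByteOffset >> 2), so the scaled value must fit in a signed 8-bit field.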
|
||||
|
||||
IndexReg - Specifies the optional index register of the broadcast operand.
|
||||
|
||||
Scale - Specifies the scaling factor of the optional index register.
|
||||
|
||||
--*/
|
||||
|
||||
.macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
.macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
.set Payload0, 0x02 # "0F 38" prefix
|
||||
.set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7)
|
||||
|
|
@ -183,6 +186,9 @@ Arguments:
|
|||
.else
|
||||
.set ModRMByte, ModRMByte + (.LGprIndex_\BaseReg\() & 7)
|
||||
.endif
|
||||
.if \ByteOffset\() != 0
|
||||
.set ModRMByte, ModRMByte + 0x40 # indicate disp8 byte offset
|
||||
.endif
|
||||
|
||||
.ifnes "\IndexReg\()", ""
|
||||
.set SibByte, 0
|
||||
|
|
@ -205,34 +211,36 @@ Arguments:
|
|||
.set SibByte, SibByte + (.LGprIndex_\BaseReg\() & 7)
|
||||
.endif
|
||||
|
||||
.ifnes "\IndexReg\()", ""
|
||||
.byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte, SibByte
|
||||
.else
|
||||
.byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte
|
||||
.ifnes "\IndexReg\()", ""
|
||||
.byte SibByte
|
||||
.endif
|
||||
.if \ByteOffset\() != 0
|
||||
.byte (\ByteOffset\() >> 2)
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
.macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
.macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
|
||||
VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
|
||||
|
||||
.endm
|
||||
|
||||
.macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
.macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
|
||||
VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
|
||||
|
||||
.endm
|
||||
|
||||
.macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
.macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
|
||||
VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
|
||||
|
||||
.endm
|
||||
|
||||
.macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, IndexReg, Scale
|
||||
.macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale
|
||||
|
||||
VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \IndexReg\(), \Scale\()
|
||||
VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\()
|
||||
|
||||
.endm
|
||||
|
|
|
|||
955
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
Normal file
955
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
Normal file
|
|
@ -0,0 +1,955 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx2.S
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX2 instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8X8KernelAvx2Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8S8 CopyPackA routine.
|
||||
//
|
||||
|
||||
.equ .LGemmU8S8CopyPackAFrame_PaddedMatrixAData, -72
|
||||
.equ .LGemmU8S8CopyPackAFrame_mask, -8
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedR13, 0
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedR12, 8
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16
|
||||
.equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24
|
||||
.equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32
|
||||
.equ .LGemmU8S8CopyPackAFrame_offb, 40
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8S8 CopyPackB routine.
|
||||
//
|
||||
|
||||
.equ .LGemmU8S8CopyPackBFrame_PaddedMatrixBData, -72
|
||||
.equ .LGemmU8S8CopyPackBFrame_Padding, -8
|
||||
.equ .LGemmU8S8CopyPackBFrame_SavedRbx, 0
|
||||
.equ .LGemmU8S8CopyPackBFrame_SavedRbp, 8
|
||||
.equ .LGemmU8S8CopyPackBFrame_ReturnAddress, 16
|
||||
.equ .LGemmU8S8CopyPackBFrame_offa, 24
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine copies elements from the source matrix to the destination
|
||||
packed buffer.
|
||||
|
||||
Arguments:
|
||||
|
||||
D (rdi) - Supplies the address of the destination packed buffer.
|
||||
|
||||
A (rsi) - Supplies the address of the source matrix.
|
||||
|
||||
lda (rdx) - Supplies the number of elements per row of the source matrix.
|
||||
|
||||
CountM (rcx) - Supplies the number of rows of the source matrix to copy.
|
||||
|
||||
CountK (r8) - Supplies the number of columns of the source matrix to copy.
|
||||
|
||||
RowSumVector (r9) - Supplies the address of the buffer to receive the sums
|
||||
of the elements from each of the rows. Each sum has also been multiplied
|
||||
by the zero point offset.
|
||||
|
||||
offb - Supplies the zero point offset for the other source matrix of the
|
||||
matrix multiplication.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8S8CopyPackAAvx2):
|
||||
|
||||
//
// Save the non-volatile registers used below. The push order must match the
// SavedR13/SavedR12/SavedRbx/SavedRbp frame offsets, which are what place
// the stack argument offb at .LGemmU8S8CopyPackAFrame_offb.
//
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
|
||||
//
// Register usage for the remainder of the routine:
//
//      r10   - lda (distance between source rows)
//      r11   - number of rows remaining
//      r12   - CountK aligned up to a multiple of 4 (packed row stride)
//      xmm8  - offb broadcast to every word
//      ymm9  - every byte set to 1 (vpmaddubsw multiplier that sums
//              adjacent source bytes into words)
//      xmm10 - DWORD store mask for the final partial 16 byte chunk
//
mov r10,rdx
|
||||
mov r11,rcx
|
||||
lea r12,[r8+3]
|
||||
and r12,NOT 3 # align CountK up to quad count
|
||||
vpbroadcastw xmm8,WORD PTR .LGemmU8S8CopyPackAFrame_offb[rsp]
|
||||
vpcmpeqw ymm9,ymm9,ymm9 # generate word vector [0xFFFF]
|
||||
vpsrlw ymm9,ymm9,15 # generate word vector [0x0001]
|
||||
vpsllw ymm0,ymm9,8 # generate word vector [0x0100]
|
||||
vpor ymm9,ymm9,ymm0 # generate word vector [0x0101]
|
||||
|
||||
//
|
||||
// Compute the conditional load/store mask for an unaligned CountK.
|
||||
//
|
||||
|
||||
//
// The value broadcast below is (CountK & 15) rounded up to whole DWORDs.
// Comparing it against MlasMaskMoveAvx enables that many DWORD lanes
// (assuming the table holds ascending DWORD indices -- confirm).
//
// The mask slot and the padded buffer sit at negative offsets from rsp,
// i.e. below the stack pointer. NOTE(review): this relies on the System V
// red zone; the routine makes no calls -- confirm for other targets.
//
mov eax,r8d
|
||||
and eax,15 # isolate unaligned count
|
||||
add eax,3
|
||||
shr eax,2 # align unaligned count to quad count
|
||||
mov DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp],eax
|
||||
vpbroadcastd xmm10,DWORD PTR .LGemmU8S8CopyPackAFrame_mask[rsp]
|
||||
vpcmpgtd xmm10,xmm10,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
|
||||
|
||||
//
|
||||
// Zero initialize the padded stack buffers.
|
||||
//
|
||||
|
||||
//
// The buffer is cleared only here. Every later partial copy writes the same
// byte positions (they depend only on CountK, which does not change), so the
// padding bytes beyond the copied data stay zero for all rows.
//
vpxor xmm0,xmm0,xmm0
|
||||
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0
|
||||
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0
|
||||
|
||||
//
|
||||
// Process 4 rows of matrix A in a loop.
|
||||
//
|
||||
|
||||
sub r11,4
|
||||
jb .LCopyPackA.ProcessRemainingRows
|
||||
|
||||
.LCopyPackA.ProcessNextRowM4:
|
||||
vpxor xmm0,xmm0,xmm0 # clear row accumulators
|
||||
vpxor xmm1,xmm1,xmm1
|
||||
vpxor xmm2,xmm2,xmm2
|
||||
vpxor xmm3,xmm3,xmm3
|
||||
mov rdx,rsi
|
||||
mov rcx,rdi
|
||||
lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows
|
||||
lea rdi,[rdi+r12*4] # advance next matrix D by 4 rows
|
||||
mov rbx,r8 # reload columns remaining
|
||||
sub rbx,32
|
||||
jb .LCopyPackA.ProcessRemainingColumnsM4
|
||||
|
||||
.LCopyPackA.ProcessNextColumnLoopM4:
|
||||
lea rax,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
vmovdqu ymm4,YMMWORD PTR [rdx]
|
||||
vmovdqu ymm5,YMMWORD PTR [rdx+r10]
|
||||
vmovdqu ymm6,YMMWORD PTR [rax]
|
||||
vmovdqu ymm7,YMMWORD PTR [rax+r10]
|
||||
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
|
||||
vmovdqu YMMWORD PTR [rcx],ymm4
|
||||
vmovdqu YMMWORD PTR [rcx+r12],ymm5
|
||||
vmovdqu YMMWORD PTR [rax],ymm6
|
||||
vmovdqu YMMWORD PTR [rax+r12],ymm7
|
||||
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
|
||||
vpmaddubsw ymm5,ymm5,ymm9
|
||||
vpaddw ymm1,ymm1,ymm5
|
||||
vpmaddubsw ymm6,ymm6,ymm9
|
||||
vpaddw ymm2,ymm2,ymm6
|
||||
vpmaddubsw ymm7,ymm7,ymm9
|
||||
vpaddw ymm3,ymm3,ymm7
|
||||
add rdx,32 # advance matrix A by 32 bytes
|
||||
add rcx,32 # advance matrix D by 32 bytes
|
||||
sub rbx,32 # subtract columns remaining
|
||||
jae .LCopyPackA.ProcessNextColumnLoopM4
|
||||
|
||||
.LCopyPackA.ProcessRemainingColumnsM4:
|
||||
add rbx,32 # correct for over-subtract above
|
||||
jz .LCopyPackA.ReduceRowSumVectorM4
|
||||
test bl,16 # (CountK & 16) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan16M4
|
||||
lea rax,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
vmovdqu xmm4,XMMWORD PTR [rdx]
|
||||
vmovdqu xmm5,XMMWORD PTR [rdx+r10]
|
||||
vmovdqu xmm6,XMMWORD PTR [rax]
|
||||
vmovdqu xmm7,XMMWORD PTR [rax+r10]
|
||||
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
|
||||
vmovdqu XMMWORD PTR [rcx],xmm4
|
||||
vmovdqu XMMWORD PTR [rcx+r12],xmm5
|
||||
vmovdqu XMMWORD PTR [rax],xmm6
|
||||
vmovdqu XMMWORD PTR [rax+r12],xmm7
|
||||
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
|
||||
vpmaddubsw xmm5,xmm5,xmm9
|
||||
vpaddw ymm1,ymm1,ymm5
|
||||
vpmaddubsw xmm6,xmm6,xmm9
|
||||
vpaddw ymm2,ymm2,ymm6
|
||||
vpmaddubsw xmm7,xmm7,xmm9
|
||||
vpaddw ymm3,ymm3,ymm7
|
||||
add rdx,16 # advance matrix A by 16 bytes
|
||||
add rcx,16 # advance matrix D by 16 bytes
|
||||
test bl,15 # test for unaligned columns
|
||||
jz .LCopyPackA.ReduceRowSumVectorM4
|
||||
|
||||
//
|
||||
// Copy the unaligned CountK columns to a zero padded stack buffer.
|
||||
//
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan16M4:
|
||||
lea rbp,.LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
|
||||
test bl,8 # (CountK & 8) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan8M4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
mov rax,QWORD PTR [rdx+r10]
|
||||
mov QWORD PTR [rbp+16],rax
|
||||
mov rax,QWORD PTR [r13]
|
||||
mov QWORD PTR [rbp+32],rax
|
||||
mov rax,QWORD PTR [r13+r10]
|
||||
mov QWORD PTR [rbp+48],rax
|
||||
add rdx,8
|
||||
add rbp,8 # advance padded buffer destination
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan8M4:
|
||||
test bl,4 # (CountK & 4) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan4M4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
mov eax,DWORD PTR [rdx+r10]
|
||||
mov DWORD PTR [rbp+16],eax
|
||||
mov eax,DWORD PTR [r13]
|
||||
mov DWORD PTR [rbp+32],eax
|
||||
mov eax,DWORD PTR [r13+r10]
|
||||
mov DWORD PTR [rbp+48],eax
|
||||
add rdx,4
|
||||
add rbp,4 # advance padded buffer destination
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan4M4:
|
||||
test bl,2 # (CountK & 2) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan2M4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
movzx eax,WORD PTR [rdx+r10]
|
||||
mov WORD PTR [rbp+16],ax
|
||||
movzx eax,WORD PTR [r13]
|
||||
mov WORD PTR [rbp+32],ax
|
||||
movzx eax,WORD PTR [r13+r10]
|
||||
mov WORD PTR [rbp+48],ax
|
||||
add rdx,2
|
||||
add rbp,2 # advance padded buffer destination
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan2M4:
|
||||
test bl,1 # (CountK & 1) != 0?
|
||||
jz .LCopyPackA.ProcessPaddedMatrixADataM4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
movzx eax,BYTE PTR [rdx+r10]
|
||||
mov BYTE PTR [rbp+16],al
|
||||
movzx eax,BYTE PTR [r13]
|
||||
mov BYTE PTR [rbp+32],al
|
||||
movzx eax,BYTE PTR [r13+r10]
|
||||
mov BYTE PTR [rbp+48],al
|
||||
|
||||
//
|
||||
// Process the remaining CountK columns using the zero padded stack buffer.
|
||||
//
|
||||
|
||||
.LCopyPackA.ProcessPaddedMatrixADataM4:
|
||||
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
|
||||
vmovdqu xmm5,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+16]
|
||||
vmovdqu xmm6,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32]
|
||||
vmovdqu xmm7,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+48]
|
||||
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
|
||||
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
|
||||
vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5
|
||||
vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6
|
||||
vpmaskmovd XMMWORD PTR [rax+r12],xmm10,xmm7
|
||||
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
|
||||
vpmaddubsw xmm5,xmm5,xmm9
|
||||
vpaddw ymm1,ymm1,ymm5
|
||||
vpmaddubsw xmm6,xmm6,xmm9
|
||||
vpaddw ymm2,ymm2,ymm6
|
||||
vpmaddubsw xmm7,xmm7,xmm9
|
||||
vpaddw ymm3,ymm3,ymm7
|
||||
|
||||
//
|
||||
// Reduce the sums for the four rows of output.
|
||||
//
|
||||
|
||||
.LCopyPackA.ReduceRowSumVectorM4:
|
||||
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0
|
||||
vphaddw ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2
|
||||
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0
|
||||
vextracti128 xmm1,ymm0,1 # extract high pairs
|
||||
vpaddw xmm0,xmm0,xmm1 # reduce low/high pairs
|
||||
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce 32-bit sum
|
||||
vmovdqu XMMWORD PTR [r9],xmm0
|
||||
add r9,4*4 # advance row sum vector by 4 DWORDs
|
||||
sub r11,4 # subtract rows remaining
|
||||
jae .LCopyPackA.ProcessNextRowM4
|
||||
|
||||
.LCopyPackA.ProcessRemainingRows:
|
||||
add r11,4 # correct for over-subtract above
|
||||
jz .LCopyPackA.ExitRoutine
|
||||
|
||||
//
|
||||
// Process a single row of matrix A in a loop.
|
||||
//
|
||||
|
||||
.LCopyPackA.ProcessNextRowM1:
|
||||
vpxor xmm0,xmm0,xmm0 # clear row accumulator
|
||||
mov rdx,rsi
|
||||
mov rcx,rdi
|
||||
add rsi,r10
|
||||
add rdi,r12
|
||||
mov rbx,r8 # reload columns remaining
|
||||
sub rbx,32
|
||||
jb .LCopyPackA.ProcessRemainingColumnsM1
|
||||
|
||||
.LCopyPackA.ProcessNextColumnLoopM1:
|
||||
vmovdqu ymm4,YMMWORD PTR [rdx]
|
||||
vmovdqu YMMWORD PTR [rcx],ymm4
|
||||
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
|
||||
add rdx,32 # advance matrix A by 32 bytes
|
||||
add rcx,32 # advance matrix D by 32 bytes
|
||||
sub rbx,32 # subtract columns remaining
|
||||
jae .LCopyPackA.ProcessNextColumnLoopM1
|
||||
|
||||
.LCopyPackA.ProcessRemainingColumnsM1:
|
||||
add rbx,32 # correct for over-subtract above
|
||||
jz .LCopyPackA.ReduceRowSumVectorM1
|
||||
test bl,16 # (CountK & 16) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan16M1
|
||||
vmovdqu xmm4,XMMWORD PTR [rdx]
|
||||
vmovdqu XMMWORD PTR [rcx],xmm4
|
||||
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
|
||||
add rdx,16 # advance matrix A by 16 bytes
|
||||
add rcx,16 # advance matrix D by 16 bytes
|
||||
test bl,15 # test for unaligned columns
|
||||
jz .LCopyPackA.ReduceRowSumVectorM1
|
||||
|
||||
//
|
||||
// Copy the unaligned CountK columns to a zero padded stack buffer.
|
||||
//
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan16M1:
|
||||
lea rbp,.LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
|
||||
test bl,8 # (CountK & 8) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan8M1
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
add rdx,8
|
||||
add rbp,8 # advance padded buffer destination
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan8M1:
|
||||
test bl,4 # (CountK & 4) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan4M1
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
add rdx,4
|
||||
add rbp,4 # advance padded buffer destination
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan4M1:
|
||||
test bl,2 # (CountK & 2) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan2M1
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
add rdx,2
|
||||
add rbp,2 # advance padded buffer destination
|
||||
|
||||
.LCopyPackA.CopyRemainingCountKLessThan2M1:
|
||||
test bl,1 # (CountK & 1) != 0?
|
||||
jz .LCopyPackA.ProcessPaddedMatrixADataM1
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
|
||||
//
|
||||
// Process the remaining CountK columns using the zero padded stack buffer.
|
||||
//
|
||||
|
||||
.LCopyPackA.ProcessPaddedMatrixADataM1:
|
||||
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
|
||||
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
|
||||
//
// The 128-bit load above leaves the upper half of ymm4 zero, so the 256-bit
// operations below contribute nothing from the upper lane.
//
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns
|
||||
|
||||
//
|
||||
// Reduce the sum for the single row of output.
|
||||
//
|
||||
|
||||
.LCopyPackA.ReduceRowSumVectorM1:
|
||||
vextracti128 xmm1,ymm0,1 # extract high pairs
|
||||
vpaddw xmm0,xmm0,xmm1 # reduction
|
||||
vphaddw xmm0,xmm0,xmm0
|
||||
vphaddw xmm0,xmm0,xmm0
|
||||
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce
|
||||
vmovd DWORD PTR [r9],xmm0
|
||||
add r9,4 # advance row sum vector by 1 DWORD
|
||||
dec r11 # decrement rows remaining
|
||||
jnz .LCopyPackA.ProcessNextRowM1
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LCopyPackA.ExitRoutine:
|
||||
vzeroupper
|
||||
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine copies elements from the source matrix to the destination
|
||||
packed buffer.
|
||||
|
||||
Arguments:
|
||||
|
||||
D (rdi) - Supplies the address of the destination packed buffer.
|
||||
|
||||
B (rsi) - Supplies the address of the source matrix.
|
||||
|
||||
ldb (rdx) - Supplies the number of elements per row of the source matrix.
|
||||
|
||||
CountN (rcx) - Supplies the number of columns of the source matrix to copy.
|
||||
|
||||
CountK (r8) - Supplies the number of rows of the source matrix to copy.
|
||||
|
||||
ColumnSumVector (r9) - Supplies the address of the buffer to receive the sums
|
||||
of the elements from each of the columns. Each sum has also been
|
||||
multiplied by the zero point offset.
|
||||
|
||||
offa - Supplies the zero point offset for the other source matrix of the
|
||||
matrix multiplication.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8S8CopyPackBAvx2):
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
|
||||
mov r10,rdx
|
||||
vpbroadcastw ymm7,WORD PTR .LGemmU8S8CopyPackBFrame_offa[rsp]
|
||||
vpcmpeqw ymm8,ymm8,ymm8 # generate word vector [0xFFFF]
|
||||
vpsrlw ymm8,ymm8,15 # generate word vector [0x0001]
|
||||
vpsllw ymm0,ymm8,8 # generate word vector [0x0100]
|
||||
vpor ymm8,ymm8,ymm0 # generate word vector [0x0101]
|
||||
|
||||
//
|
||||
// Process 16 columns of matrix B in a loop.
|
||||
//
|
||||
|
||||
sub rcx,16
|
||||
jb .LCopyPackB.ProcessRemainingColumns
|
||||
|
||||
.LCopyPackB.ProcessNextColumnN16:
|
||||
vpxor xmm0,xmm0,xmm0 # clear column accumulators
|
||||
vpxor xmm1,xmm1,xmm1
|
||||
mov rdx,rsi
|
||||
add rsi,16 # advance next matrix B by 16 columns
|
||||
mov rbx,r8 # reload rows remaining
|
||||
sub rbx,4
|
||||
jb .LCopyPackB.ProcessRemainingRowsN16
|
||||
|
||||
.LCopyPackB.ProcessNextRowLoopN16:
|
||||
lea rax,[rdx+r10*2] # compute matrix B plus 2 rows
|
||||
vmovdqu xmm2,XMMWORD PTR [rdx] # load 4 rows
|
||||
vmovdqu xmm3,XMMWORD PTR [rdx+r10]
|
||||
vmovdqu xmm4,XMMWORD PTR [rax]
|
||||
vmovdqu xmm5,XMMWORD PTR [rax+r10]
|
||||
lea rdx,[rdx+r10*4] # advance matrix B by 4 rows
|
||||
|
||||
.LCopyPackB.InterleaveRowDataN16:
|
||||
vpunpcklbw xmm6,xmm2,xmm3 # interleave row data
|
||||
vpunpckhbw xmm3,xmm2,xmm3
|
||||
vpunpcklbw xmm2,xmm4,xmm5
|
||||
vpunpckhbw xmm5,xmm4,xmm5
|
||||
vpunpcklwd xmm4,xmm6,xmm2
|
||||
vpunpckhwd xmm6,xmm6,xmm2
|
||||
vpunpcklwd xmm2,xmm3,xmm5
|
||||
vpunpckhwd xmm3,xmm3,xmm5
|
||||
vinsertf128 ymm4,ymm4,xmm6,1
|
||||
vinsertf128 ymm2,ymm2,xmm3,1
|
||||
vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows
|
||||
vmovdqu YMMWORD PTR [rdi+32],ymm2
|
||||
vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per column
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to column accumulators
|
||||
vpmaddubsw ymm2,ymm8,ymm2
|
||||
vpaddw ymm1,ymm1,ymm2
|
||||
add rdi,64 # advance matrix D by 64 bytes
|
||||
sub rbx,4 # subtract rows remaining
|
||||
jae .LCopyPackB.ProcessNextRowLoopN16
|
||||
|
||||
//
|
||||
// Process the remaining rows (fewer than 4) where each row has 16 columns.
|
||||
//
|
||||
|
||||
.LCopyPackB.ProcessRemainingRowsN16:
|
||||
add rbx,4 # correct for over-subtract above
|
||||
jz .LCopyPackB.ReduceColumnSumVectorN16
|
||||
vmovdqu xmm2,XMMWORD PTR [rdx]
|
||||
vpxor xmm3,xmm3,xmm3
|
||||
vpxor xmm4,xmm4,xmm4
|
||||
vpxor xmm5,xmm5,xmm5
|
||||
xor ebx,ebx # no more rows remaining
|
||||
test r8b,2 # (CountK & 2) != 0?
|
||||
jz .LCopyPackB.InterleaveRowDataN16
|
||||
vmovdqu xmm3,XMMWORD PTR [rdx+r10]
|
||||
test r8b,1 # (CountK & 1) != 0?
|
||||
jz .LCopyPackB.InterleaveRowDataN16
|
||||
vmovdqu xmm4,XMMWORD PTR [rdx+r10*2]
|
||||
jmp .LCopyPackB.InterleaveRowDataN16
|
||||
|
||||
.LCopyPackB.ReduceColumnSumVectorN16:
|
||||
vpmaddwd ymm0,ymm0,ymm7 # multiply by offset and reduce
|
||||
vpmaddwd ymm1,ymm1,ymm7 # multiply by offset and reduce
|
||||
vmovdqu YMMWORD PTR [r9],ymm0
|
||||
vmovdqu YMMWORD PTR [r9+32],ymm1
|
||||
add r9,16*4 # advance column sum vector by 16 DWORDs
|
||||
sub rcx,16 # subtract columns remaining
|
||||
jae .LCopyPackB.ProcessNextColumnN16
|
||||
|
||||
.LCopyPackB.ProcessRemainingColumns:
|
||||
add rcx,16 # correct for over-subtract above
|
||||
jnz .LCopyPackB.ProcessColumnNUnaligned
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LCopyPackB.ExitRoutine:
|
||||
vzeroupper
|
||||
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
//
|
||||
// Process the remaining columns of matrix B.
|
||||
//
|
||||
|
||||
.LCopyPackB.ProcessColumnNUnaligned:
|
||||
vpxor xmm0,xmm0,xmm0 # clear column accumulators
|
||||
vpxor xmm1,xmm1,xmm1
|
||||
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp],ymm0
|
||||
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+32],ymm0
|
||||
sub r8,4
|
||||
jb .LCopyPackB.ProcessRemainingRowsNUnaligned
|
||||
|
||||
.LCopyPackB.ProcessNextRowLoopNUnaligned:
|
||||
mov rdx,rsi
|
||||
lea rbp,.LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp]
|
||||
test cl,8 # (CountN & 8) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan8K4
|
||||
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
mov rax,QWORD PTR [rdx+r10]
|
||||
mov QWORD PTR [rbp+16],rax
|
||||
mov rax,QWORD PTR [r11]
|
||||
mov QWORD PTR [rbp+32],rax
|
||||
mov rax,QWORD PTR [r11+r10]
|
||||
mov QWORD PTR [rbp+48],rax
|
||||
add rdx,8 # advance matrix B
|
||||
add rbp,8 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan8K4:
|
||||
test cl,4 # (CountN & 4) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan4K4
|
||||
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
mov eax,DWORD PTR [rdx+r10]
|
||||
mov DWORD PTR [rbp+16],eax
|
||||
mov eax,DWORD PTR [r11]
|
||||
mov DWORD PTR [rbp+32],eax
|
||||
mov eax,DWORD PTR [r11+r10]
|
||||
mov DWORD PTR [rbp+48],eax
|
||||
add rdx,4 # advance matrix B
|
||||
add rbp,4 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan4K4:
|
||||
test cl,2 # (CountN & 2) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan2K4
|
||||
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
movzx eax,WORD PTR [rdx+r10]
|
||||
mov WORD PTR [rbp+16],ax
|
||||
movzx eax,WORD PTR [r11]
|
||||
mov WORD PTR [rbp+32],ax
|
||||
movzx eax,WORD PTR [r11+r10]
|
||||
mov WORD PTR [rbp+48],ax
|
||||
add rdx,2 # advance matrix B
|
||||
add rbp,2 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan2K4:
|
||||
test cl,1 # (CountN & 1) != 0?
|
||||
jz .LCopyPackB.ProcessPaddedMatrixBData
|
||||
lea r11,[rdx+r10*2] # compute matrix B plus 2 rows
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
movzx eax,BYTE PTR [rdx+r10]
|
||||
mov BYTE PTR [rbp+16],al
|
||||
movzx eax,BYTE PTR [r11]
|
||||
mov BYTE PTR [rbp+32],al
|
||||
movzx eax,BYTE PTR [r11+r10]
|
||||
mov BYTE PTR [rbp+48],al
|
||||
|
||||
.LCopyPackB.ProcessPaddedMatrixBData:
|
||||
vmovdqu xmm2,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp]
|
||||
vmovdqu xmm3,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+16]
|
||||
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+32]
|
||||
vmovdqu xmm5,XMMWORD PTR .LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp+48]
|
||||
vpunpcklbw xmm6,xmm2,xmm3 # interleave row data
|
||||
vpunpckhbw xmm3,xmm2,xmm3
|
||||
vpunpcklbw xmm2,xmm4,xmm5
|
||||
vpunpckhbw xmm5,xmm4,xmm5
|
||||
vpunpcklwd xmm4,xmm6,xmm2
|
||||
vpunpckhwd xmm6,xmm6,xmm2
|
||||
vpunpcklwd xmm2,xmm3,xmm5
|
||||
vpunpckhwd xmm3,xmm3,xmm5
|
||||
vinsertf128 ymm4,ymm4,xmm6,1
|
||||
vinsertf128 ymm2,ymm2,xmm3,1
|
||||
vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows
|
||||
vmovdqu YMMWORD PTR [rdi+32],ymm2
|
||||
vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row
|
||||
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
|
||||
vpmaddubsw ymm2,ymm8,ymm2
|
||||
vpaddw ymm1,ymm1,ymm2
|
||||
lea rsi,[rsi+r10*4] # advance next matrix B by 4 rows
|
||||
add rdi,64 # advance matrix D by 64 bytes
|
||||
sub r8,4 # subtract rows remaining
|
||||
jae .LCopyPackB.ProcessNextRowLoopNUnaligned
|
||||
|
||||
.LCopyPackB.ProcessRemainingRowsNUnaligned:
|
||||
add r8,4
|
||||
jz .LCopyPackB.ReduceColumnSumVectorNUnaligned
|
||||
|
||||
//
|
||||
// Process the less than 4 remaining rows where the row has less than 16 columns.
|
||||
//
|
||||
|
||||
lea rbp,.LGemmU8S8CopyPackBFrame_PaddedMatrixBData[rsp]
|
||||
vpxor xmm6,xmm6,xmm6
|
||||
vmovdqu YMMWORD PTR [rbp],ymm6
|
||||
vmovdqu YMMWORD PTR [rbp+32],ymm6
|
||||
|
||||
.LCopyPackB.CopyUnalignedRowLoop:
|
||||
lea r11,[rbp+16] # advance next padded buffer by 16 bytes
|
||||
mov rdx,rsi
|
||||
test cl,8 # (CountN & 8) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan8KSmall
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
add rdx,8 # advance matrix B
|
||||
add rbp,8 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan8KSmall:
|
||||
test cl,4 # (CountN & 4) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan4KSmall
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
add rdx,4 # advance matrix B
|
||||
add rbp,4 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan4KSmall:
|
||||
test cl,2 # (CountN & 2) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan2KSmall
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
add rdx,2 # advance matrix B
|
||||
add rbp,2 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan2KSmall:
|
||||
test cl,1 # (CountN & 1) != 0?
|
||||
jz .LCopyPackB.DoneCopyRemainingCountNKSmall
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
|
||||
.LCopyPackB.DoneCopyRemainingCountNKSmall:
|
||||
dec r8
|
||||
jz .LCopyPackB.ProcessPaddedMatrixBData
|
||||
add rsi,r10 # advance next matrix B by 1 row
|
||||
mov rbp,r11
|
||||
jmp .LCopyPackB.CopyUnalignedRowLoop
|
||||
|
||||
.LCopyPackB.ReduceColumnSumVectorNUnaligned:
|
||||
vpmaddwd ymm0,ymm0,ymm7 # multiply by offset and reduce
|
||||
vpmaddwd ymm1,ymm1,ymm7 # multiply by offset and reduce
|
||||
vmovdqu YMMWORD PTR [r9],ymm0
|
||||
vmovdqu YMMWORD PTR [r9+32],ymm1
|
||||
jmp .LCopyPackB.ExitRoutine
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 16).
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
ymm0 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 16).
|
||||
|
||||
ymm2 - Supplies the broadcast value loaded from matrix A.
|
||||
|
||||
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
vpmaddubsw ymm3,ymm2,ymm0 # multiply unsigned A bytes (ymm2) by signed B bytes (ymm0), sum adjacent pairs to words
|
||||
vpmaddwd ymm3,ymm3,ymm12 # multiply by 0x0001 words to sum adjacent word pairs to dwords
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 # accumulate results for the first B vector
|
||||
vpmaddubsw ymm2,ymm2,ymm1 # repeat for the second B vector (ymm1), overwriting the broadcast value
|
||||
vpmaddwd ymm2,ymm2,ymm12 # widen word pairs to dwords as above
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 # accumulate results for the second B vector
|
||||
.else
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 # single vector case accumulates into Vec2Reg only
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm11 - Supplies the block accumulators.
|
||||
|
||||
ymm12 - Supplies a 256-bit vector with the broadcasted word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \RowCount\() == 1
|
||||
vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()] # broadcast 4 bytes from the matrix A row
|
||||
vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()] # single row: multiply straight from matrix B memory
|
||||
vpmaddwd ymm3,ymm3,ymm12 # multiply by 0x0001 words to sum adjacent word pairs to dwords
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd ymm4,ymm4,ymm3 # accumulate first B vector results into ymm4
|
||||
vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32] # second B vector, overwriting the broadcast value
|
||||
vpmaddwd ymm2,ymm2,ymm12 # widen word pairs to dwords as above
|
||||
vpaddd ymm5,ymm5,ymm2 # accumulate second B vector results into ymm5
|
||||
.else
|
||||
vpaddd ymm5,ymm5,ymm3 # single vector case accumulates into ymm5 only
|
||||
.endif
|
||||
.else
|
||||
vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] # multiple rows: load the B vector once and reuse it for every row
|
||||
EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm11 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B by 64 bytes
|
||||
sub rbp,4 # subtract the 1 quad (4 bytes) just consumed from the row length
|
||||
jnz .LComputeBlockBy1Loop\@ # loop until the row length is exhausted
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8S8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8S8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
QuadCountK (rcx) - Supplies the number of quad columns from matrix A and
|
||||
the number of quad rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value
|
||||
is accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8S8KernelAvx2)
|
||||
C_UNDERSCORE(MlasGemmU8S8KernelAvx2):
|
||||
|
||||
push rbp # save the registers restored at .LExitKernel
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp] # load ldc from the stack frame
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] # load the ZeroMode flag
|
||||
mov r11,rdi # keep a copy of the matrix A address
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp] # load the RowSumVector address
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp] # load the ColumnSumVector address
|
||||
vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF]
|
||||
vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001]
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices. ProcessCountM is defined outside this section (presumably in a shared AVX2 header, TODO confirm).
|
||||
//
|
||||
|
||||
cmp r8,3 # dispatch on CountM
|
||||
ja .LProcessCountM4 # more than 3 rows available: process 4 rows
|
||||
je .LProcessCountM3 # exactly 3 rows
|
||||
cmp r8,1
|
||||
je .LProcessCountM1 # exactly 1 row, otherwise fall through to the 2 row case
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
mov r8d,4 # return 4 rows handled
|
||||
ProcessCountM 4, Fallthrough
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
mov eax,r8d # return the number of rows handled
|
||||
vzeroupper
|
||||
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.end
|
||||
136
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S
Normal file
136
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512BW.S
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx512BW.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX512BW instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8S8KernelAvx512Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single cell of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
AccumReg - Supplies the register to accumulate into.
|
||||
|
||||
Mult1Reg - Supplies the first multiplication operand register.
|
||||
|
||||
Mult2Reg - Supplies the second multiplication operand register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
zmm4 - Supplies a scratch register for intermediate results.
|
||||
|
||||
zmm5 - Supplies a 512-bit vector with the broadcasted word value 0x0001.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\() # multiply unsigned bytes of Mult1Reg by signed bytes of Mult2Reg, sum adjacent pairs to words
|
||||
vpmaddwd zmm4,zmm4,zmm5 # multiply by 0x0001 words (zmm5) to sum adjacent word pairs to dwords
|
||||
vpaddd \AccumReg\(),\AccumReg\(),zmm4 # add the dword results to the accumulator
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \ColumnCount\() >= 48
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] # 1st packed B block, feeds accumulators zmm26-zmm31
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()] # 2nd packed B block (one stride on), feeds zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()] # 3rd packed B block (two strides on), feeds zmm14-zmm19
|
||||
.elseif \ColumnCount\() >= 32
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()] # 1st packed B block, feeds accumulators zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()] # 2nd packed B block (one stride on), feeds zmm14-zmm19
|
||||
.else
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()] # only packed B block, feeds accumulators zmm14-zmm19
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512BW
|
||||
|
||||
.end
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx512Common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains common kernel macros and structures for the quantized
|
||||
integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
|
||||
AVX512VNNI kernels.
|
||||
|
||||
--*/
|
||||
|
||||
#include "QgemmU8X8KernelAvx512Common.h"
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if ((\RowCount\() & 1) == 0)
|
||||
sub rbp,4*4 # at least 4 quads remaining? (unrolled path is only emitted for even RowCount)
|
||||
jb .LProcessRemainingBlocks\@ # fewer than 4 quads: use the 1 quad loop
|
||||
|
||||
.LComputeBlockBy4Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0*64, 0
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 1*64, 4
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 2*64, 8
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 3*64, 12
|
||||
add rdi,4*4 # advance matrix A by 4 quads
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4*4 # advance matrix A plus 3 rows by 4 quads
|
||||
.endif
|
||||
add rsi,4*64 # advance matrix B by 4 blocks of 64 bytes
|
||||
sub rbp,4*4 # decrement row length remaining by 4 quads
|
||||
jae .LComputeBlockBy4Loop\@
|
||||
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,4*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@ # nothing left for the 1 quad loop
|
||||
.endif
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 quad
|
||||
.endif
|
||||
add rsi,64 # advance matrix B by 64 bytes
|
||||
sub rbp,4 # decrement row length remaining by 1 quad
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
106
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S
Normal file
106
onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx512Vnni.S
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8S8KernelAvx512Vnni.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the quantized integer matrix/matrix
|
||||
multiply operation (QGEMM).
|
||||
|
||||
This implementation uses AVX512VNNI instructions.
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8S8KernelAvx512Common.h"
|
||||
#include "AssembleAvx512Vnni.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate each row of the output
|
||||
block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
.if \ColumnCount\() >= 48
|
||||
vmovdqu32 zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] # 1st packed B block, feeds accumulators zmm26-zmm31
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+r14+\VectorOffset\()] # 2nd packed B block (one stride on), feeds zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14*2+\VectorOffset\()] # 3rd packed B block (two strides on), feeds zmm14-zmm19
|
||||
.elseif \ColumnCount\() >= 32
|
||||
vmovdqu32 zmm1,ZMMWORD PTR [rsi+\VectorOffset\()] # 1st packed B block, feeds accumulators zmm20-zmm25
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+r14+\VectorOffset\()] # 2nd packed B block (one stride on), feeds zmm14-zmm19
|
||||
.else
|
||||
vmovdqu32 zmm2,ZMMWORD PTR [rsi+\VectorOffset\()] # only packed B block, feeds accumulators zmm14-zmm19
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpbusdsZmmZmmZmm zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpbusdsZmmZmmZmm zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpbusdsZmmZmmZmm zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
//
|
||||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8X8KernelAvx512Function U8S8, Avx512Vnni
|
||||
|
||||
.end
|
||||
|
|
@ -18,11 +18,10 @@ Abstract:
|
|||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
#include "QgemmU8X8KernelAvx2Common.h"
|
||||
|
||||
.intel_syntax noprefix
|
||||
|
||||
.text
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8U8 CopyPackA routine.
|
||||
//
|
||||
|
|
@ -47,22 +46,7 @@ Abstract:
|
|||
.equ .LGemmU8U8CopyPackBFrame_ReturnAddress, 16
|
||||
.equ .LGemmU8U8CopyPackBFrame_offa, 24
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8U8 kernel.
|
||||
//
|
||||
|
||||
.equ .LGemmU8U8KernelFrame_mask, -8
|
||||
.equ .LGemmU8U8KernelFrame_SavedR14, 0
|
||||
.equ .LGemmU8U8KernelFrame_SavedR13, 8
|
||||
.equ .LGemmU8U8KernelFrame_SavedR12, 16
|
||||
.equ .LGemmU8U8KernelFrame_SavedRbx, 24
|
||||
.equ .LGemmU8U8KernelFrame_SavedRbp, 32
|
||||
.equ .LGemmU8U8KernelFrame_ReturnAddress, 40
|
||||
.equ .LGemmU8U8KernelFrame_ldc, 48
|
||||
.equ .LGemmU8U8KernelFrame_RowSumVector, 56
|
||||
.equ .LGemmU8U8KernelFrame_ColumnSumVector, 64
|
||||
.equ .LGemmU8U8KernelFrame_DepthValue, 72
|
||||
.equ .LGemmU8U8KernelFrame_ZeroMode, 80
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
|
|
@ -138,13 +122,15 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
//
|
||||
// Process 4 rows of matrix A in a loop.
|
||||
//
|
||||
// For each row, zero extend the source bytes to 16-bits and write to the packed
|
||||
// buffer. The packed buffer has the same data ordering as the source bytes, but
|
||||
// the stride is CountK aligned up to an even number of 16-bit values.
|
||||
// Zero extend the source bytes to 16-bits and write to the packed buffer.
|
||||
//
|
||||
// The packed buffer has the same data ordering as the source bytes, but CountK
|
||||
// is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding
|
||||
// bytes are zero filled.
|
||||
//
|
||||
// These 16-bit values are also accumulated into an intermediate per-row
|
||||
// accumulator. CountK cannot be greater than 256 to avoid overflowing these
|
||||
// 16-bit accumulators.
|
||||
// accumulator. CountK cannot be greater than 128 to avoid overflowing these
|
||||
// signed 16-bit accumulators.
|
||||
//
|
||||
|
||||
sub r11,4
|
||||
|
|
@ -164,12 +150,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
jb .LCopyPackA.ProcessRemainingColumnsM4
|
||||
|
||||
.LCopyPackA.ProcessNextColumnLoopM4:
|
||||
lea rax,[rdx+r10*2] # compute matrix A plus two rows
|
||||
lea rax,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
vpmovzxbw ymm4,XMMWORD PTR [rdx]
|
||||
vpmovzxbw ymm5,XMMWORD PTR [rdx+r10]
|
||||
vpmovzxbw ymm6,XMMWORD PTR [rax]
|
||||
vpmovzxbw ymm7,XMMWORD PTR [rax+r10]
|
||||
lea rax,[rcx+r12*4] # compute matrix D plus two rows
|
||||
lea rax,[rcx+r12*4] # compute matrix D plus 2 rows
|
||||
vmovdqu YMMWORD PTR [rcx],ymm4
|
||||
vmovdqu YMMWORD PTR [rcx+r12*2],ymm5
|
||||
vmovdqu YMMWORD PTR [rax],ymm6
|
||||
|
|
@ -194,7 +180,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp]
|
||||
test bl,8 # (CountK & 8) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan8M4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus two rows
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
mov rax,QWORD PTR [rdx+r10]
|
||||
|
|
@ -209,7 +195,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
.LCopyPackA.CopyRemainingCountKLessThan8M4:
|
||||
test bl,4 # (CountK & 4) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan4M4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus two rows
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
mov eax,DWORD PTR [rdx+r10]
|
||||
|
|
@ -224,7 +210,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
.LCopyPackA.CopyRemainingCountKLessThan4M4:
|
||||
test bl,2 # (CountK & 2) != 0?
|
||||
jz .LCopyPackA.CopyRemainingCountKLessThan2M4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus two rows
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
movzx eax,WORD PTR [rdx+r10]
|
||||
|
|
@ -239,7 +225,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
.LCopyPackA.CopyRemainingCountKLessThan2M4:
|
||||
test bl,1 # (CountK & 1) != 0?
|
||||
jz .LCopyPackA.ProcessPaddedMatrixADataM4
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus two rows
|
||||
lea r13,[rdx+r10*2] # compute matrix A plus 2 rows
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
movzx eax,BYTE PTR [rdx+r10]
|
||||
|
|
@ -258,7 +244,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
vpmovzxbw ymm5,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+16]
|
||||
vpmovzxbw ymm6,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32]
|
||||
vpmovzxbw ymm7,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+48]
|
||||
lea rax,[rcx+r12*4] # compute matrix D plus two rows
|
||||
lea rax,[rcx+r12*4] # compute matrix D plus 2 rows
|
||||
vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4
|
||||
vpmaskmovd YMMWORD PTR [rcx+r12*2],ymm9,ymm5
|
||||
vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6
|
||||
|
|
@ -276,20 +262,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackAAvx2):
|
|||
//
|
||||
|
||||
.LCopyPackA.ReduceRowSumVectorM4:
|
||||
vpunpckldq ymm4,ymm0,ymm1 # [A5 B5 A4 B4 A1 B1 A0 B0]
|
||||
vpunpckhdq ymm5,ymm0,ymm1 # [A7 B7 A6 B6 A3 B3 A2 B2]
|
||||
vpunpckldq ymm6,ymm2,ymm3 # [C5 D5 C4 D4 C1 D1 C0 D0]
|
||||
vpunpckhdq ymm7,ymm2,ymm3 # [C7 D7 C6 D6 C3 D3 C2 D2]
|
||||
vpunpcklqdq ymm0,ymm4,ymm6 # [A4 B4 C4 D4 A0 B0 C0 D0]
|
||||
vpunpckhqdq ymm1,ymm4,ymm6 # [A5 B5 C5 D5 A1 B1 C1 D1]
|
||||
vpunpcklqdq ymm2,ymm5,ymm7 # [A6 B6 C6 D6 A2 B2 C2 D2]
|
||||
vpunpckhqdq ymm3,ymm5,ymm7 # [A7 B7 C7 D7 A3 B3 C3 D3]
|
||||
vpaddw ymm0,ymm0,ymm1 # reduction
|
||||
vpaddw ymm0,ymm0,ymm2
|
||||
vpaddw ymm0,ymm0,ymm3
|
||||
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0
|
||||
vphaddw ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2
|
||||
vphaddw ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0
|
||||
vextracti128 xmm1,ymm0,1 # extract high pairs
|
||||
vpaddw xmm0,xmm0,xmm1 # reduction
|
||||
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce
|
||||
vpaddw xmm0,xmm0,xmm1 # reduce low/high pairs
|
||||
vpmaddwd xmm0,xmm0,xmm8 # multiply by offset and reduce 32-bit sum
|
||||
vmovdqu XMMWORD PTR [r9],xmm0
|
||||
add r9,4*4 # advance row sum vector by 4 dwords
|
||||
sub r11,4 # subtract rows remaining
|
||||
|
|
@ -436,7 +414,6 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
push rbx
|
||||
|
||||
mov r10,rdx
|
||||
mov r11,rcx
|
||||
vpbroadcastw ymm5,WORD PTR .LGemmU8U8CopyPackBFrame_offa[rsp]
|
||||
|
||||
//
|
||||
|
|
@ -450,7 +427,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
// Process 16 columns of matrix B in a loop.
|
||||
//
|
||||
|
||||
sub r11,16
|
||||
sub rcx,16
|
||||
jb .LCopyPackB.ProcessRemainingColumns
|
||||
|
||||
.LCopyPackB.ProcessNextColumnN16:
|
||||
|
|
@ -463,9 +440,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
jb .LCopyPackB.ProcessRemainingRowsN16
|
||||
|
||||
.LCopyPackB.ProcessNextRowLoopN16:
|
||||
vmovdqu xmm2,XMMWORD PTR [rdx] # load two rows
|
||||
vmovdqu xmm2,XMMWORD PTR [rdx] # load 2 rows
|
||||
vmovdqu xmm3,XMMWORD PTR [rdx+r10]
|
||||
lea rdx,[rdx+r10*2] # advance matrix B by two rows
|
||||
lea rdx,[rdx+r10*2] # advance matrix B by 2 rows
|
||||
vpunpcklbw xmm4,xmm2,xmm3 # interleave row data
|
||||
vpunpckhbw xmm3,xmm2,xmm3
|
||||
vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows
|
||||
|
|
@ -475,7 +452,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rdi,32 # advance matrix D by 32 bytes
|
||||
vpaddw ymm0,ymm0,ymm4 # accumulate per column
|
||||
vpaddw ymm1,ymm1,ymm3
|
||||
sub rbx,2 # subtract columns remaining
|
||||
sub rbx,2 # subtract rows remaining
|
||||
jae .LCopyPackB.ProcessNextRowLoopN16
|
||||
|
||||
.LCopyPackB.ProcessRemainingRowsN16:
|
||||
|
|
@ -495,12 +472,12 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
vpmaddwd ymm1,ymm1,ymm5 # multiply by offset and reduce
|
||||
vmovdqu YMMWORD PTR [r9],ymm0
|
||||
vmovdqu YMMWORD PTR [r9+32],ymm1
|
||||
add r9,64 # advance column sum vector by 16 dwords
|
||||
sub r11,16 # subtract columns remaining
|
||||
add r9,64 # advance column sum vector by 16 DWORDs
|
||||
sub rcx,16 # subtract columns remaining
|
||||
jae .LCopyPackB.ProcessNextColumnN16
|
||||
|
||||
.LCopyPackB.ProcessRemainingColumns:
|
||||
add r11,16 # correct for over-subtract above
|
||||
add rcx,16 # correct for over-subtract above
|
||||
jnz .LCopyPackB.ProcessColumnNUnaligned
|
||||
|
||||
//
|
||||
|
|
@ -527,7 +504,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
.LCopyPackB.ProcessNextRowLoopNUnaligned:
|
||||
mov rdx,rsi
|
||||
lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp]
|
||||
test r11b,8 # (CountN & 8) != 0?
|
||||
test cl,8 # (CountN & 8) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan8K2
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
|
|
@ -537,7 +514,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rbp,8 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan8K2:
|
||||
test r11b,4 # (CountN & 4) != 0?
|
||||
test cl,4 # (CountN & 4) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan4K2
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
|
|
@ -547,7 +524,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rbp,4 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan4K2:
|
||||
test r11b,2 # (CountN & 2) != 0?
|
||||
test cl,2 # (CountN & 2) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan2K2
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
|
|
@ -557,7 +534,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rbp,2 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan2K2:
|
||||
test r11b,1 # (CountN & 1) != 0?
|
||||
test cl,1 # (CountN & 1) != 0?
|
||||
jz .LCopyPackB.ProcessPaddedMatrixBDataK2
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
|
|
@ -575,9 +552,9 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
vpmovzxbw ymm3,xmm3
|
||||
vpaddw ymm0,ymm0,ymm4 # accumulate per column
|
||||
vpaddw ymm1,ymm1,ymm3
|
||||
lea rsi,[rsi+r10*2] # advance next matrix B by two rows
|
||||
lea rsi,[rsi+r10*2] # advance next matrix B by 2 rows
|
||||
add rdi,32 # advance matrix D by 32 bytes
|
||||
sub r8,2 # subtract columns remaining
|
||||
sub r8,2 # subtract rows remaining
|
||||
jae .LCopyPackB.ProcessNextRowLoopNUnaligned
|
||||
|
||||
.LCopyPackB.ProcessRemainingRowsNUnaligned:
|
||||
|
|
@ -585,7 +562,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
jz .LCopyPackB.ReduceColumnSumVectorNUnaligned
|
||||
mov rdx,rsi
|
||||
lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp]
|
||||
test r11b,8 # (CountN & 8) != 0?
|
||||
test cl,8 # (CountN & 8) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan8K1
|
||||
mov rax,QWORD PTR [rdx]
|
||||
mov QWORD PTR [rbp],rax
|
||||
|
|
@ -593,7 +570,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rbp,8 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan8K1:
|
||||
test r11b,4 # (CountN & 4) != 0?
|
||||
test cl,4 # (CountN & 4) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan4K1
|
||||
mov eax,DWORD PTR [rdx]
|
||||
mov DWORD PTR [rbp],eax
|
||||
|
|
@ -601,7 +578,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rbp,4 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan4K1:
|
||||
test r11b,2 # (CountN & 2) != 0?
|
||||
test cl,2 # (CountN & 2) != 0?
|
||||
jz .LCopyPackB.CopyRemainingCountNLessThan2K1
|
||||
movzx eax,WORD PTR [rdx]
|
||||
mov WORD PTR [rbp],ax
|
||||
|
|
@ -609,7 +586,7 @@ C_UNDERSCORE(MlasGemmU8U8CopyPackBAvx2):
|
|||
add rbp,2 # advance padded buffer destination
|
||||
|
||||
.LCopyPackB.CopyRemainingCountNLessThan2K1:
|
||||
test r11b,1 # (CountN & 1) != 0?
|
||||
test cl,1 # (CountN & 1) != 0?
|
||||
jz .LCopyPackB.ProcessPaddedMatrixBDataK1
|
||||
movzx eax,BYTE PTR [rdx]
|
||||
mov BYTE PTR [rbp],al
|
||||
|
|
@ -659,13 +636,12 @@ Implicit Arguments:
|
|||
|
||||
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
|
||||
|
||||
.if \ColumnCount\() == 16
|
||||
vpmaddwd ymm3,ymm2,ymm0
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3
|
||||
vpmaddwd ymm2,ymm2,ymm1
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2
|
||||
.else
|
||||
vpmaddwd ymm3,ymm2,ymm0
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3
|
||||
.endif
|
||||
|
||||
|
|
@ -696,7 +672,7 @@ Implicit Arguments:
|
|||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
|
|
@ -708,15 +684,15 @@ Implicit Arguments:
|
|||
EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), ymm4, ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+r10+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), ymm6, ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+r10*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), ymm8, ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), ymm10, ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+r10+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), ymm12, ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+r10*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), ymm14, ymm15"
|
||||
|
||||
.endm
|
||||
|
|
@ -725,8 +701,8 @@ Implicit Arguments:
|
|||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times while advancing the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
|
|
@ -736,271 +712,55 @@ Arguments:
|
|||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the number of paired columns from matrix A and the number of
|
||||
paired rows from matrix B to iterate over.
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the sum of the global depth value constant,
|
||||
// the column sums, and the row sums.
|
||||
//
|
||||
|
||||
vpbroadcastd ymm1,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp]
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 # advance ColumnSumVector by 16 columns
|
||||
.else
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
|
||||
|
||||
//
|
||||
// Iterate over PairedCountK elements from matrix A and matrix B.
|
||||
//
|
||||
// Unrolling the loop to do two iterations improves performance slightly at the
|
||||
// cost of larger code size. Balance this by only unrolling for the common case
|
||||
// of computing 16 columns for an even number of rows.
|
||||
//
|
||||
|
||||
mov rbp,rcx # reload PairedCountK
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[r10*2+r10]
|
||||
add rbx,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0)
|
||||
sub rbp,2
|
||||
jb .LProcessRemainingBlocks.\ColumnCount\().\RowCount\()
|
||||
sub rbp,2*4
|
||||
jb .LProcessRemainingBlocks\@
|
||||
|
||||
.LComputeBlockLoop.\ColumnCount\().\RowCount\():
|
||||
.LComputeBlockBy2Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 32, 4
|
||||
add rdi,2*4 # advance matrix A by 2 pairs
|
||||
.if \RowCount\() > 3
|
||||
add rbx,2*4 # advance matrix A plus 3 rows by 2 pairs
|
||||
.endif
|
||||
add rsi,2*32 # advance matrix B by 64 columns
|
||||
sub rbp,2 # subtract pairs remaining
|
||||
jae .LComputeBlockLoop.\ColumnCount\().\RowCount\()
|
||||
add rsi,2*32 # advance matrix B
|
||||
sub rbp,2*4
|
||||
jae .LComputeBlockBy2Loop\@
|
||||
|
||||
.LProcessRemainingBlocks.\ColumnCount\().\RowCount\():
|
||||
add rbp,2 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit.\ColumnCount\().\RowCount\()
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,2*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rsi,32 # advance matrix B by 32 columns
|
||||
add rsi,32 # advance matrix B
|
||||
.else
|
||||
.LComputeBlockLoop.\ColumnCount\().\RowCount\():
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 pair
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 pair
|
||||
.endif
|
||||
add rsi,32
|
||||
dec rbp # decrement pairs remaining
|
||||
jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\()
|
||||
add rsi,32 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
.endif
|
||||
|
||||
.LComputeBlockLoopExit.\ColumnCount\().\RowCount\():
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Fallthrough - Supplies a non-blank value if the macro may fall through to
|
||||
the ExitKernel label.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
r11 - Supplies the address of matrix A.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the number of paired columns from matrix A and the number of
|
||||
paired rows from matrix B to iterate over.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
|
||||
r14b - Supplies the zero mode flag.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount, Fallthrough
|
||||
|
||||
cmp r9,8
|
||||
jbe .LProcessRemainingCountN.\RowCount\()
|
||||
|
||||
.LProcessNextColumnLoop16xN.\RowCount\():
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
sub r9,16
|
||||
jb .LOutputMasked16xNBlock.\RowCount\()
|
||||
test r14b,r14b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlock.\RowCount\()
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlock.\RowCount\():
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,r11 # reload matrix A
|
||||
cmp r9,8
|
||||
ja .LProcessNextColumnLoop16xN.\RowCount\()
|
||||
test r9,r9
|
||||
jz .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN.\RowCount\():
|
||||
ProduceOutputBlock 8, \RowCount\()
|
||||
cmp r9,8
|
||||
jb .LOutputMasked8xNBlock.\RowCount\()
|
||||
test r14b,r14b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput8xNBlock.\RowCount\()
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput8xNBlock.\RowCount\():
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15"
|
||||
jmp .LExitKernel
|
||||
|
||||
.LOutputMasked16xNBlock.\RowCount\():
|
||||
test r14b,r14b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked16xNBlock.\RowCount\()
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutputMasked16xNBlock.\RowCount\():
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
|
||||
add rdx,8*4 # advance matrix C by 8 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,8*4 # advance matrix C plus 3 rows by 8 columns
|
||||
.endif
|
||||
add r9,8 # correct for over-subtract above
|
||||
|
||||
.LOutputMasked8xNBlock.\RowCount\():
|
||||
mov DWORD PTR .LGemmU8U8KernelFrame_mask[rsp],r9d
|
||||
vpbroadcastd ymm0,DWORD PTR .LGemmU8U8KernelFrame_mask[rsp]
|
||||
vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
|
||||
test r14b,r14b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked8xNBlock.\RowCount\()
|
||||
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
|
||||
|
||||
.LSkipAccumulateOutputMasked8xNBlock.\RowCount\():
|
||||
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15"
|
||||
.ifb \Fallthrough\()
|
||||
jmp .LExitKernel
|
||||
.endif
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
|
||||
|
|
@ -1021,8 +781,8 @@ Arguments:
|
|||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PairedCountK (rcx) - Supplies the number of paired columns from matrix A and
|
||||
the number of paired rows from matrix B to iterate over.
|
||||
PairCountK (rcx) - Supplies the number of pair columns from matrix A and
|
||||
the number of pair rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
|
|
@ -1042,8 +802,8 @@ Arguments:
|
|||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value
|
||||
is accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
|
@ -1061,15 +821,14 @@ C_UNDERSCORE(MlasGemmU8U8KernelAvx2):
|
|||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
push r14
|
||||
|
||||
mov rax,.LGemmU8U8KernelFrame_ldc[rsp]
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
lea r10,[rcx*4]
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
|
||||
mov r11,rdi
|
||||
mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp]
|
||||
mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp]
|
||||
movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp]
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp]
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
|
|
@ -1102,7 +861,6 @@ C_UNDERSCORE(MlasGemmU8U8KernelAvx2):
|
|||
mov eax,r8d
|
||||
vzeroupper
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
|
|
|
|||
|
|
@ -28,40 +28,27 @@ Abstract:
|
|||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to multiply and accumulate a single row of the
|
||||
This macro generates code to multiply and accumulate a single cell of the
|
||||
output block.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
AccumReg - Supplies the register to accumulate into.
|
||||
|
||||
Vec1Reg - Supplies the high block accumulator register (when ColumnCount
|
||||
is 32).
|
||||
Mult1Reg - Supplies the first multiplication operand register.
|
||||
|
||||
Vec2Reg - Supplies the low block accumulator register.
|
||||
Mult2Reg - Supplies the second multiplication operand register.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
zmm28 - Supplies the first vector loaded from matrix B.
|
||||
|
||||
zmm29 - Supplies the second vector loaded from matrix B (when ColumnCount
|
||||
is 32).
|
||||
|
||||
zmm30 - Supplies the broadcast value loaded from matrix A.
|
||||
zmm4 - Supplies a scratch register for intermediate results.
|
||||
|
||||
--*/
|
||||
|
||||
.macro MultiplyAccumulateRow ColumnCount, Vec1Reg, Vec2Reg
|
||||
.macro MultiplyAccumulateCell AccumReg, Mult1Reg, Mult2Reg
|
||||
|
||||
.if \ColumnCount\() == 32
|
||||
vpmaddwd zmm31,zmm30,zmm28
|
||||
vpaddd \Vec1Reg\(),\Vec1Reg\(),zmm31
|
||||
vpmaddwd zmm30,zmm30,zmm29
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm30
|
||||
.else
|
||||
vpmaddwd zmm31,zmm30,zmm28
|
||||
vpaddd \Vec2Reg\(),\Vec2Reg\(),zmm31
|
||||
.endif
|
||||
vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\()
|
||||
vpaddd \AccumReg\(),\AccumReg\(),zmm4
|
||||
|
||||
.endm
|
||||
|
||||
|
|
@ -78,36 +65,62 @@ Arguments:
|
|||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
zmm16-zmm27 - Supplies the block accumulators.
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw zmm28,YMMWORD PTR [rsi]
|
||||
EmitIfCountGE \ColumnCount\(), 32, "vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]"
|
||||
EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRow \ColumnCount\(), zmm16, zmm17"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]"
|
||||
EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRow \ColumnCount\(), zmm18, zmm19"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]"
|
||||
EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRow \ColumnCount\(), zmm20, zmm21"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRow \ColumnCount\(), zmm22, zmm23"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]"
|
||||
EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRow \ColumnCount\(), zmm24, zmm25"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]"
|
||||
EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRow \ColumnCount\(), zmm26, zmm27"
|
||||
.if \ColumnCount\() >= 48
|
||||
vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()]
|
||||
.elseif \ColumnCount\() >= 32
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
.else
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell zmm19,zmm3,zmm2"
|
||||
|
||||
.endm
|
||||
|
||||
|
|
@ -115,6 +128,6 @@ Implicit Arguments:
|
|||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8U8KernelAvx512Function Avx512BW
|
||||
GemmU8X8KernelAvx512Function U8U8, Avx512BW
|
||||
|
||||
.end
|
||||
|
|
|
|||
|
|
@ -16,28 +16,14 @@ Abstract:
|
|||
|
||||
--*/
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8U8 kernel.
|
||||
//
|
||||
|
||||
.equ .LGemmU8U8KernelFrame_SavedR14, 0
|
||||
.equ .LGemmU8U8KernelFrame_SavedR13, 8
|
||||
.equ .LGemmU8U8KernelFrame_SavedR12, 16
|
||||
.equ .LGemmU8U8KernelFrame_SavedRbx, 24
|
||||
.equ .LGemmU8U8KernelFrame_SavedRbp, 32
|
||||
.equ .LGemmU8U8KernelFrame_ReturnAddress, 40
|
||||
.equ .LGemmU8U8KernelFrame_ldc, 48
|
||||
.equ .LGemmU8U8KernelFrame_RowSumVector, 56
|
||||
.equ .LGemmU8U8KernelFrame_ColumnSumVector, 64
|
||||
.equ .LGemmU8U8KernelFrame_DepthValue, 72
|
||||
.equ .LGemmU8U8KernelFrame_ZeroMode, 80
|
||||
#include "QgemmU8X8KernelAvx512Common.h"
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
This macro generates code to execute the block compute macro multiple
|
||||
times and advance the matrix A and matrix B data pointers.
|
||||
|
||||
Arguments:
|
||||
|
||||
|
|
@ -47,315 +33,32 @@ Arguments:
|
|||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the number of paired columns from matrix A and the number of
|
||||
paired rows from matrix B to iterate over.
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
.macro ComputeBlockLoop ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the sum of the global depth value constant,
|
||||
// the column sums, and the row sums.
|
||||
//
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
vpbroadcastd zmm31,DWORD PTR .LGemmU8U8KernelFrame_DepthValue[rsp]
|
||||
.if \ColumnCount\() == 32
|
||||
vpaddd zmm30,zmm31,ZMMWORD PTR [r13]
|
||||
vpaddd zmm31,zmm31,ZMMWORD PTR [r13+64]
|
||||
add r13,32*4 # advance ColumnSumVector by 32 columns
|
||||
.else
|
||||
vpaddd zmm31,zmm31,ZMMWORD PTR [r13]
|
||||
.endif
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm16,zmm30,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17,zmm31,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm18,zmm30,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19,zmm31,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm20,zmm30,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21,zmm31,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm22,zmm30,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm31,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm30,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25,zmm31,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm26,zmm30,DWORD PTR [r12+20]{1to16}"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27,zmm31,DWORD PTR [r12+20]{1to16}"
|
||||
|
||||
//
|
||||
// Iterate over PairedCountK elements from matrix A and matrix B.
|
||||
//
|
||||
|
||||
mov rbp,rcx # reload PairedCountK
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[r10*2+r10]
|
||||
add rbx,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
|
||||
.LComputeBlockLoop.\ColumnCount\().\RowCount\():
|
||||
ComputeBlock \ColumnCount\(), \RowCount\()
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlock \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 pair
|
||||
.if \RowCount\() > 3
|
||||
add rbx,4 # advance matrix A plus 3 rows by 1 pair
|
||||
.endif
|
||||
add rsi,32
|
||||
dec rbp # decrement pairs remaining
|
||||
jnz .LComputeBlockLoop.\ColumnCount\().\RowCount\()
|
||||
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
r11 - Supplies the address of matrix A.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the number of paired columns from matrix A and the number of
|
||||
paired rows from matrix B to iterate over.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
|
||||
r14b - Supplies the zero mode flag.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount
|
||||
|
||||
cmp r9,16
|
||||
jbe .LProcessRemainingCountN.\RowCount\()
|
||||
|
||||
.LProcessNextColumnLoop32xN.\RowCount\():
|
||||
ProduceOutputBlock 32, \RowCount\()
|
||||
lea rsi,[rsi+r10*8] # advance matrix B by 8*PairedCountK
|
||||
test r14b,r14b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput32xNBlock.\RowCount\()
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm16,zmm16,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm18,zmm18,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm22,zmm22,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm26,zmm26,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput32xNBlock.\RowCount\():
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm16"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm18"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm20"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm22"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm26"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
|
||||
.LOutput16xNBlock.\RowCount\():
|
||||
sub r9,16
|
||||
jae .LOutput16xNBlockWithMask.\RowCount\()
|
||||
lea rcx,[r9+16] # correct for over-subtract above
|
||||
mov ebp,1
|
||||
shl ebp,cl
|
||||
dec ebp
|
||||
kmovw k1,ebp # update mask for remaining columns
|
||||
xor r9,r9 # no more columns remaining
|
||||
|
||||
.LOutput16xNBlockWithMask.\RowCount\():
|
||||
test r14b,r14b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlockWithMask.\RowCount\()
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm21{k1},zmm21,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23{k1},zmm23,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm25{k1},zmm25,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm27{k1},zmm27,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlockWithMask.\RowCount\():
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm17"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm19"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm21"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm23"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm25"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm27"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,r11 # reload matrix A
|
||||
cmp r9,16
|
||||
ja .LProcessNextColumnLoop32xN.\RowCount\()
|
||||
test r9,r9
|
||||
jz .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN.\RowCount\():
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
jmp .LOutput16xNBlock.\RowCount\()
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates the common AVX512 code for the inner kernel to compute
|
||||
matrix multiplication.
|
||||
|
||||
Arguments:
|
||||
|
||||
Isa - Supplies the instruction set architecture string for function tags.
|
||||
|
||||
--*/
|
||||
|
||||
.macro GemmU8U8KernelAvx512Function Isa
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8U8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8U8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PairedCountK (rcx) - Supplies the number of paired columns from matrix A and
|
||||
the number of paired rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\())
|
||||
C_UNDERSCORE(MlasGemmU8U8Kernel\Isa\()):
|
||||
|
||||
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
push r14
|
||||
|
||||
mov rax,.LGemmU8U8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
lea r10,[rcx*4]
|
||||
mov r11,rdi
|
||||
mov r12,.LGemmU8U8KernelFrame_RowSumVector[rsp]
|
||||
mov r13,.LGemmU8U8KernelFrame_ColumnSumVector[rsp]
|
||||
movzx r14,BYTE PTR .LGemmU8U8KernelFrame_ZeroMode[rsp]
|
||||
mov ebp,-1
|
||||
kmovw k1,ebp # update mask to write all columns
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
|
||||
//
|
||||
|
||||
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
mov r8d,6 # return 6 rows handled
|
||||
ProcessCountM 6
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
mov eax,r8d
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
add rsi,32 # advance matrix B
|
||||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.endm
|
||||
|
|
|
|||
|
|
@ -38,50 +38,69 @@ Arguments:
|
|||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
VectorOffset - Supplies the byte offset from matrix B to fetch elements.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rbx - Supplies the address into the matrix A data plus 3 rows.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
r10 - Supplies the length in bytes of a row from matrix A.
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
zmm16-zmm27 - Supplies the block accumulators.
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
zmm14-zmm31 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ComputeBlock ColumnCount, RowCount
|
||||
.macro ComputeBlock ColumnCount, RowCount, VectorOffset, BroadcastOffset
|
||||
|
||||
vpmovzxbw zmm28,YMMWORD PTR [rsi]
|
||||
.if \ColumnCount\() == 32
|
||||
vpmovzxbw zmm29,YMMWORD PTR [rsi+r10*8]
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm30,DWORD PTR [rdi]"
|
||||
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm16,zmm28,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmZmm zmm17,zmm29,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm30,DWORD PTR [rdi+r10]"
|
||||
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm18,zmm28,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmZmm zmm19,zmm29,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm30,DWORD PTR [rdi+r10*2]"
|
||||
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm20,zmm28,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmZmm zmm21,zmm29,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm30,DWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm22,zmm28,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmZmm zmm23,zmm29,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm30,DWORD PTR [rbx+r10]"
|
||||
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm24,zmm28,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmZmm zmm25,zmm29,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm30,DWORD PTR [rbx+r10*2]"
|
||||
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm26,zmm28,zmm30"
|
||||
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmZmm zmm27,zmm29,zmm30"
|
||||
.if \ColumnCount\() >= 32
|
||||
.if \ColumnCount\() >= 48
|
||||
vpmovzxbw zmm0,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14*2+\VectorOffset\()]
|
||||
.else
|
||||
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm17,zmm28,rdi"
|
||||
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm19,zmm28,rdi,r10,1"
|
||||
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm21,zmm28,rdi,r10,2"
|
||||
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm23,zmm28,rbx"
|
||||
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm25,zmm28,rbx,r10,1"
|
||||
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm27,zmm28,rbx,r10,2"
|
||||
vpmovzxbw zmm1,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+r14+\VectorOffset\()]
|
||||
.endif
|
||||
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm26,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm20,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm14,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm27,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm21,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm15,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm28,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm22,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm16,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [rbx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm29,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm23,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm17,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm30,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm24,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm18,zmm3,zmm2"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [rbx+rcx*2+\BroadcastOffset\()]"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "VpdpwssdZmmZmmZmm zmm31,zmm3,zmm0"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "VpdpwssdZmmZmmZmm zmm25,zmm3,zmm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "VpdpwssdZmmZmmZmm zmm19,zmm3,zmm2"
|
||||
.else
|
||||
vpmovzxbw zmm2,YMMWORD PTR [rsi+\VectorOffset\()]
|
||||
EmitIfCountGE \RowCount\(), 1, "VpdpwssdZmmZmmBroadcast zmm14,zmm2,rdi,\BroadcastOffset\()"
|
||||
EmitIfCountGE \RowCount\(), 2, "VpdpwssdZmmZmmBroadcast zmm15,zmm2,rdi,\BroadcastOffset\(),rcx,1"
|
||||
EmitIfCountGE \RowCount\(), 3, "VpdpwssdZmmZmmBroadcast zmm16,zmm2,rdi,\BroadcastOffset\(),rcx,2"
|
||||
EmitIfCountGE \RowCount\(), 4, "VpdpwssdZmmZmmBroadcast zmm17,zmm2,rbx,\BroadcastOffset\()"
|
||||
EmitIfCountGE \RowCount\(), 5, "VpdpwssdZmmZmmBroadcast zmm18,zmm2,rbx,\BroadcastOffset\(),rcx,1"
|
||||
EmitIfCountGE \RowCount\(), 6, "VpdpwssdZmmZmmBroadcast zmm19,zmm2,rbx,\BroadcastOffset\(),rcx,2"
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
|
@ -90,6 +109,6 @@ Implicit Arguments:
|
|||
// Generate the GEMM kernel.
|
||||
//
|
||||
|
||||
GemmU8U8KernelAvx512Function Avx512Vnni
|
||||
GemmU8X8KernelAvx512Function U8U8, Avx512Vnni
|
||||
|
||||
.end
|
||||
|
|
|
|||
273
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h
Normal file
273
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2Common.h
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelAvx2Common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains common kernel macros and structures for the quantized
|
||||
integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels.
|
||||
|
||||
--*/
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8S8 and U8U8 kernels.
|
||||
//
|
||||
|
||||
// Byte offsets applied to rsp to address frame slots as NAME[rsp].
// The mask slot lies below rsp; ProcessCountM stores the remaining column
// count there and broadcasts it to build the partial-store mask.
// NOTE(review): a negative rsp offset presumably relies on the System V
// red zone - confirm.
.equ .LGemmU8X8KernelFrame_mask, -8
|
||||
// NOTE(review): the Saved* offsets presumably mirror the push order of the
// kernel prologue, which is not part of this header - confirm.
.equ .LGemmU8X8KernelFrame_SavedR13, 0
|
||||
.equ .LGemmU8X8KernelFrame_SavedR12, 8
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbx, 16
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbp, 24
|
||||
.equ .LGemmU8X8KernelFrame_ReturnAddress, 32
|
||||
// Slots from ldc upward are presumably the stack-passed kernel arguments;
// ProduceOutputBlock reads DepthValue directly from its slot.
.equ .LGemmU8X8KernelFrame_ldc, 40
|
||||
.equ .LGemmU8X8KernelFrame_RowSumVector, 48
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumVector, 56
|
||||
.equ .LGemmU8X8KernelFrame_DepthValue, 64
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 72
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
|
||||
ymm4-ymm15 - Supplies the block accumulators.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the sum of the global depth value constant,
|
||||
// the column sums, and the row sums.
|
||||
//
|
||||
|
||||
// ymm1 starts as the broadcast depth value. For 16 columns, ymm0/ymm1 become
// depth value plus the column sums at [r13] and [r13+32]; otherwise only ymm1
// is formed and r13 is left unchanged.
vpbroadcastd ymm1,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
|
||||
.if \ColumnCount\() == 16
|
||||
vpaddd ymm0,ymm1,YMMWORD PTR [r13]
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13+32]
|
||||
add r13,16*4 # advance ColumnSumVector by 16 columns
|
||||
.else
|
||||
vpaddd ymm1,ymm1,YMMWORD PTR [r13]
|
||||
.endif
|
||||
// Broadcast the row sum of each row into that row's odd-numbered accumulator.
EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r12]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r12+4]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r12+8]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r12+12]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r12+16]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r12+20]"
|
||||
// Add the column terms. The even accumulator must be formed first because the
// odd accumulator still holds the bare row sum it is derived from.
// NOTE(review): EmitIfCount2GE presumably emits only when RowCount >= N and
// ColumnCount >= 16 - confirm against the macro definition.
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1"
|
||||
|
||||
//
|
||||
// Iterate over the length of a matrix A row to produce the output accumulators.
|
||||
//
|
||||
|
||||
// rbx is only needed for rows 4-6: it addresses matrix A plus 3 rows while the
// compute loop runs.
.if \RowCount\() > 3
|
||||
lea rbx,[rcx*2+rcx]
|
||||
add rbx,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
ComputeBlockLoop \ColumnCount\(), \RowCount\()
|
||||
// rbx is repurposed after the loop to address matrix C plus 3 rows, which is
// what the output code in ProcessCountM indexes for rows 4-6.
.if \RowCount\() > 3
|
||||
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Fallthrough - Supplies a non-blank value if the macro may fall through to
|
||||
the ExitKernel label.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
r11 - Supplies the address of matrix A.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10b - Supplies the zero mode flag.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount, Fallthrough
|
||||
|
||||
// With 8 or fewer columns remaining, skip the 16-column path and go straight
// to the 8-column tail.
cmp r9,8
|
||||
jbe .LProcessRemainingCountN\@
|
||||
|
||||
.LProcessNextColumnLoop16xN\@:
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
// The subtraction borrows when fewer than 16 columns remain; that case takes
// the partially masked 16-column output path.
sub r9,16
|
||||
jb .LOutputMasked16xNBlock\@
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx+32]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax+32]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2+32]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx+32],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax+32],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2+32],ymm15"
|
||||
// rbx is not advanced here; ProduceOutputBlock recomputes it from rdx for the
// next block.
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,r11 # reload matrix A
|
||||
cmp r9,8
|
||||
ja .LProcessNextColumnLoop16xN\@
|
||||
test r9,r9
|
||||
jz .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN\@:
|
||||
ProduceOutputBlock 8, \RowCount\()
|
||||
// Exactly 8 columns are written whole from the odd accumulators; fewer than 8
// use the masked path.
cmp r9,8
|
||||
jb .LOutputMasked8xNBlock\@
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput8xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput8xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm15"
|
||||
jmp .LExitKernel
|
||||
|
||||
// Fewer than 16 columns remain after a 16-column block was computed: write
// the first 8 columns whole from the even accumulators, then continue into
// the masked path below for the odd accumulators.
.LOutputMasked16xNBlock\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked16xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutputMasked16xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [rbx],ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [rbx+rax],ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [rbx+rax*2],ymm14"
|
||||
add rdx,8*4 # advance matrix C by 8 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,8*4 # advance matrix C plus 3 rows by 8 columns
|
||||
.endif
|
||||
// r9 held (remaining - 16); adding 8 leaves the column count still to be
// written after the 8 columns stored above.
add r9,8 # correct for over-subtract above
|
||||
|
||||
.LOutputMasked8xNBlock\@:
|
||||
// Build the lane mask in ymm0: the remaining column count is spilled to the
// mask slot, broadcast, and compared greater-than against MlasMaskMoveAvx.
// NOTE(review): MlasMaskMoveAvx is presumably the lane indices 0..7 - confirm.
mov DWORD PTR .LGemmU8X8KernelFrame_mask[rsp],r9d
|
||||
vpbroadcastd ymm0,DWORD PTR .LGemmU8X8KernelFrame_mask[rsp]
|
||||
vpcmpgtd ymm0,ymm0,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip]
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutputMasked8xNBlock\@
|
||||
// The even accumulators are free on both routes into this label, so they are
// used as scratch for the masked loads of the existing matrix C values.
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [rbx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14"
|
||||
|
||||
.LSkipAccumulateOutputMasked8xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [rbx],ymm0,ymm11"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [rbx+rax],ymm0,ymm13"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [rbx+rax*2],ymm0,ymm15"
|
||||
// The explicit jump is emitted only when Fallthrough is blank; otherwise
// execution continues into whatever follows this macro expansion.
.ifb \Fallthrough\()
|
||||
jmp .LExitKernel
|
||||
.endif
|
||||
|
||||
.endm
|
||||
403
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h
Normal file
403
onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Common.h
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
QgemmU8X8KernelAvx512Common.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module contains common kernel macros and structures for the quantized
|
||||
integer matrix/matrix multiply operation (QGEMM) for the AVX512BW and
|
||||
AVX512VNNI kernels.
|
||||
|
||||
--*/
|
||||
|
||||
//
|
||||
// Stack frame layout for the U8S8 and U8U8 kernels.
//
// All offsets are relative to rsp after the kernel prologue has pushed rbp,
// rbx, r12, r13 and r14 (in that order), so the most recently pushed register
// (r14) sits at offset 0 and the return address follows the five saved
// registers at offset 40. The entries from ldc onwards are the arguments that
// the caller passes on the stack, one 8 byte slot each.
|
||||
//
|
||||
|
||||
.equ .LGemmU8X8KernelFrame_SavedR14, 0
|
||||
.equ .LGemmU8X8KernelFrame_SavedR13, 8
|
||||
.equ .LGemmU8X8KernelFrame_SavedR12, 16
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbx, 24
|
||||
.equ .LGemmU8X8KernelFrame_SavedRbp, 32
|
||||
.equ .LGemmU8X8KernelFrame_ReturnAddress, 40
|
||||
.equ .LGemmU8X8KernelFrame_ldc, 48
|
||||
.equ .LGemmU8X8KernelFrame_RowSumVector, 56
|
||||
.equ .LGemmU8X8KernelFrame_ColumnSumVector, 64
|
||||
.equ .LGemmU8X8KernelFrame_DepthValue, 72
|
||||
.equ .LGemmU8X8KernelFrame_ZeroMode, 80
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to produce an output block for a set of columns
|
||||
and rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
ColumnCount - Supplies the number of columns to produce.
|
||||
|
||||
RowCount - Supplies the number of rows to produce.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address into the matrix A data.
|
||||
|
||||
rsi - Supplies the address into the matrix B data.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProduceOutputBlock ColumnCount, RowCount
|
||||
|
||||
//
|
||||
// Initialize the accumulators with the sum of the global depth value constant,
|
||||
// the column sums, and the row sums.
//
// The per-column base values are assigned in descending register order (zmm2,
// zmm1, zmm0 for 48 columns; zmm1, zmm0 for 32 columns; zmm0 for 16 columns),
// so zmm0 always holds the base values for the last 16 columns of the block.
|
||||
//
|
||||
|
||||
vpbroadcastd zmm3,DWORD PTR .LGemmU8X8KernelFrame_DepthValue[rsp]
|
||||
.if \ColumnCount\() >= 32
|
||||
.if \ColumnCount\() >= 48
|
||||
vpaddd zmm2,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13+64]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+128]
|
||||
.else
|
||||
vpaddd zmm1,zmm3,ZMMWORD PTR [r13]
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13+64]
|
||||
.endif
|
||||
add_immed r13,\ColumnCount\()*4 # advance ColumnSumVector by N columns
|
||||
.else
// The 16 column variant leaves r13 unchanged. ProcessCountM only expands this
// width for the final (16 or fewer) columns, after which it exits the kernel.
|
||||
vpaddd zmm0,zmm3,ZMMWORD PTR [r13]
|
||||
.endif
// Seed one accumulator per (row, 16 column group) by adding the row's sum,
// broadcast to all 16 lanes from [r12 + 4 * row], to the group's base values.
// zmm14-zmm19 take zmm0 (last 16 columns), zmm20-zmm25 take zmm1 and
// zmm26-zmm31 take zmm2. EmitIfCount2GE is defined outside this section;
// presumably it emits the instruction only when RowCount and ColumnCount both
// reach the listed thresholds.
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r12+4]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r12+8]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r12+12]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r12+16]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r12+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r12+20]{1to16}"
|
||||
EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r12+20]{1to16}"
|
||||
|
||||
//
|
||||
// Iterate over the length of a matrix A row to produce the output accumulators.
//
// rbx serves two purposes: before the compute loop it addresses matrix A plus
// three rows (rdi + 3 * rcx); after the loop it is repointed at matrix C plus
// three rows (rdx + 3 * rax) for the caller's output code. It is only set up
// when more than three rows are processed. ComputeBlockLoop is defined outside
// this section.
|
||||
//
|
||||
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[rcx*2+rcx]
|
||||
add rbx,rdi # compute matrix A plus 3 rows
|
||||
.endif
|
||||
ComputeBlockLoop \ColumnCount\(), \RowCount\()
|
||||
.if \RowCount\() > 3
|
||||
lea rbx,[rdx+rax*2] # compute matrix C plus 3 rows
|
||||
add rbx,rax
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates code to compute matrix multiplication for a fixed set
|
||||
of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
RowCount - Supplies the number of rows to process.
|
||||
|
||||
Implicit Arguments:
|
||||
|
||||
rax - Supplies the length in bytes of a row from matrix C.
|
||||
|
||||
rdi - Supplies the address of matrix A.
|
||||
|
||||
rsi - Supplies the address of matrix B.
|
||||
|
||||
rdx - Supplies the address of matrix C.
|
||||
|
||||
r11 - Supplies the address of matrix A.
|
||||
|
||||
r9 - Supplies the number of columns from matrix B and matrix C to iterate
|
||||
over.
|
||||
|
||||
rcx - Supplies the length in bytes of a row from matrix A.
|
||||
|
||||
r10b - Supplies the zero mode flag.
|
||||
|
||||
r12 - Supplies the address of the row sum vector.
|
||||
|
||||
r13 - Supplies the address of the column sum vector.
|
||||
|
||||
r14 - Supplies the stride in bytes between packed blocks of matrix B.
|
||||
|
||||
--*/
|
||||
|
||||
.macro ProcessCountM RowCount
|
||||
|
||||
// Dispatch on the number of columns remaining in r9: more than 32 columns use
// the 48 column path, 17 to 32 columns fall into the 32 column path directly
// below, and 16 or fewer columns use the remaining-columns path.
cmp r9,32
|
||||
ja .LProcessNextColumnLoop48xN\@
|
||||
cmp r9,16
|
||||
jbe .LProcessRemainingCountN\@
|
||||
|
||||
.LProcessNextColumnLoop32xN\@:
|
||||
ProduceOutputBlock 32, \RowCount\()
|
||||
add rsi,r14 # advance matrix B by packed block stride
|
||||
|
||||
// Writes the 16 columns held in zmm20-zmm25. This label is also the join point
// for the 48 column path once it has stored its first 16 columns.
.LOutput32xNBlock\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput32xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput32xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm23"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm24"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm25"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
|
||||
// Writes the final 16 column group of the block (zmm14-zmm19) under mask k1.
// The subtraction borrows when fewer than 16 columns remain; jae is taken when
// a full 16 columns are available, in which case k1 keeps its current value.
.LOutput16xNBlock\@:
|
||||
sub r9,16
|
||||
jae .LOutput16xNBlockWithMask\@
// Partial group: build k1 = (1 << remaining) - 1. This overwrites rcx (the
// matrix A row length), which is safe because r9 is zeroed below and the
// kernel exits after this group is written.
|
||||
lea rcx,[r9+16] # correct for over-subtract above
|
||||
mov ebp,1
|
||||
shl ebp,cl
|
||||
dec ebp
|
||||
kmovw k1,ebp # update mask for remaining columns
|
||||
xor r9,r9 # no more columns remaining
|
||||
|
||||
.LOutput16xNBlockWithMask\@:
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput16xNBlockWithMask\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput16xNBlockWithMask\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm17"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax]{k1},zmm18"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2]{k1},zmm19"
// rbx is not advanced here: ProduceOutputBlock recomputes it from rdx for the
// next block.
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
mov rdi,r11 # reload matrix A
// Pick the path for the next block using the same thresholds as at entry, or
// leave the kernel when no columns remain.
|
||||
cmp r9,32
|
||||
ja .LProcessNextColumnLoop48xN\@
|
||||
cmp r9,16
|
||||
ja .LProcessNextColumnLoop32xN\@
|
||||
test r9,r9
|
||||
jz .LExitKernel
|
||||
|
||||
.LProcessRemainingCountN\@:
|
||||
ProduceOutputBlock 16, \RowCount\()
|
||||
jmp .LOutput16xNBlock\@
|
||||
|
||||
// 48 column path: matrix B is advanced by two packed block strides, the first
// 16 columns (zmm26-zmm31) are written here, and the remaining accumulators
// are written by the 32 column and 16 column output code above.
.LProcessNextColumnLoop48xN\@:
|
||||
ProduceOutputBlock 48, \RowCount\()
|
||||
lea rsi,[rsi+r14*2] # advance matrix B by packed block stride
|
||||
test r10b,r10b # ZeroMode?
|
||||
jnz .LSkipAccumulateOutput48xNBlock\@
|
||||
EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]"
|
||||
EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]"
|
||||
EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [rbx]"
|
||||
EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [rbx+rax]"
|
||||
EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [rbx+rax*2]"
|
||||
|
||||
.LSkipAccumulateOutput48xNBlock\@:
|
||||
EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26"
|
||||
EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27"
|
||||
EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28"
|
||||
EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [rbx],zmm29"
|
||||
EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [rbx+rax],zmm30"
|
||||
EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [rbx+rax*2],zmm31"
|
||||
add rdx,16*4 # advance matrix C by 16 columns
|
||||
.if \RowCount\() > 3
|
||||
add rbx,16*4 # advance matrix C plus 3 rows by 16 columns
|
||||
.endif
|
||||
sub r9,16
|
||||
jmp .LOutput32xNBlock\@
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro generates the common AVX512 code for the inner kernel to compute
|
||||
matrix multiplication.
|
||||
|
||||
Arguments:
|
||||
|
||||
Type - Supplies the kernel type string for function tags.
|
||||
|
||||
Isa - Supplies the instruction set architecture string for function tags.
|
||||
|
||||
--*/
|
||||
|
||||
.macro GemmU8X8KernelAvx512Function Type, Isa
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (rdi) - Supplies the address of matrix A. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackAAvx2.
|
||||
|
||||
B (rsi) - Supplies the address of matrix B. The matrix data has been packed
|
||||
using MlasGemmU8X8CopyPackBAvx2.
|
||||
|
||||
C (rdx) - Supplies the address of matrix C.
|
||||
|
||||
PairedCountK (rcx) - Supplies the number of paired columns from matrix A and
|
||||
the number of paired rows from matrix B to iterate over.
|
||||
|
||||
CountM (r8) - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN (r9) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
RowSumVector - Supplies the sum of each row from matrix A multiplied by the
|
||||
zero point offset of matrix B. These values are accumulated into every
|
||||
row of matrix C.
|
||||
|
||||
ColumnSumVector - Supplies the sum of each column from matrix B multiplied
|
||||
by the zero point offset of matrix A. These values are accumulated into
|
||||
every column of matrix C.
|
||||
|
||||
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
||||
of matrix A multiplied by the zero point offset of matrix B. This value is
|
||||
accumulated into every element of matrix C.
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
|
||||
.globl C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\())
|
||||
C_UNDERSCORE(MlasGemm\Type\()Kernel\Isa\()):
|
||||
|
||||
// Save the non-volatile registers used by the kernel. The push order must stay
// in sync with the .LGemmU8X8KernelFrame_* offsets used to read the stack
// arguments below.
push rbp
|
||||
push rbx
|
||||
push r12
|
||||
push r13
|
||||
push r14
|
||||
|
||||
mov rax,.LGemmU8X8KernelFrame_ldc[rsp]
|
||||
shl rax,2 # convert ldc to bytes
|
||||
shl rcx,2 # convert to row length
|
||||
movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp]
// r11 keeps the original matrix A address so ProcessCountM can reload rdi for
// each block of columns.
|
||||
mov r11,rdi
|
||||
mov r12,.LGemmU8X8KernelFrame_RowSumVector[rsp]
|
||||
mov r13,.LGemmU8X8KernelFrame_ColumnSumVector[rsp]
|
||||
mov ebp,-1
|
||||
kmovw k1,ebp # update mask to write all columns
// The packed matrix B stride in r14 is derived from the row length in rcx
// (already scaled by 4 above): rcx * 16 for U8S8 kernels and rcx * 8
// otherwise. For the U8S8 AVX512BW kernel, ebp is negated from -1 to 1 and
// broadcast to every word of zmm5.
|
||||
.ifeqs "\Type\()", "U8S8"
|
||||
.ifeqs "\Isa\()", "Avx512BW"
|
||||
neg ebp
|
||||
vpbroadcastw zmm5,ebp # generate 512-bit word vector [0x0001]
|
||||
.endif
|
||||
mov r14,rcx
|
||||
shl r14,4 # compute matrix B packed stride
|
||||
.else
|
||||
lea r14,[rcx*8] # compute matrix B packed stride
|
||||
.endif
|
||||
|
||||
//
|
||||
// Process CountM rows of the matrices.
//
// Any CountM above 5 is handled as 6 rows (r8d is overwritten with 6); for
// smaller counts r8 is returned unchanged. Every ProcessCountM expansion ends
// in an unconditional jump, so the expansions below never fall through into
// one another.
//
// NOTE(review): a CountM of zero matches none of the compares and falls into
// the two row path; presumably callers never pass zero -- confirm.
|
||||
//
|
||||
|
||||
cmp r8,5
|
||||
ja .LProcessCountM6
|
||||
je .LProcessCountM5
|
||||
cmp r8,3
|
||||
ja .LProcessCountM4
|
||||
je .LProcessCountM3
|
||||
cmp r8,1
|
||||
je .LProcessCountM1
|
||||
|
||||
.LProcessCountM2:
|
||||
ProcessCountM 2
|
||||
|
||||
.LProcessCountM4:
|
||||
ProcessCountM 4
|
||||
|
||||
.LProcessCountM6:
|
||||
mov r8d,6 # return 6 rows handled
|
||||
ProcessCountM 6
|
||||
|
||||
//
|
||||
// Restore non-volatile registers and return.
|
||||
//
|
||||
|
||||
.LExitKernel:
|
||||
mov eax,r8d
|
||||
vzeroupper
|
||||
|
||||
pop r14
|
||||
pop r13
|
||||
pop r12
|
||||
pop rbx
|
||||
pop rbp
|
||||
ret
|
||||
|
||||
// The odd row counts are emitted after the shared exit path; each expansion
// reaches .LExitKernel by jumping back to it.
.LProcessCountM1:
|
||||
ProcessCountM 1
|
||||
|
||||
.LProcessCountM3:
|
||||
ProcessCountM 3
|
||||
|
||||
.LProcessCountM5:
|
||||
ProcessCountM 5
|
||||
|
||||
.endm
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/cpu/math/matmul_integer.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/util/qmath.h"
|
||||
|
|
@ -35,6 +36,9 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
|
||||
template <>
|
||||
Status MatMulInteger<uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
auto a = ctx->Input<Tensor>(0);
|
||||
auto b = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(a != nullptr && b != nullptr);
|
||||
|
|
@ -71,13 +75,16 @@ Status MatMulInteger<uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
|
|||
b_offset,
|
||||
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
|
||||
static_cast<int>(helper.N()),
|
||||
nullptr);
|
||||
thread_pool);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status MatMulInteger<uint8_t, int8_t>::Compute(OpKernelContext* ctx) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
concurrency::ThreadPool* thread_pool = ctx_internal->GetOperatorThreadPool();
|
||||
|
||||
auto a = ctx->Input<Tensor>(0);
|
||||
auto b = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(a != nullptr && b != nullptr);
|
||||
|
|
@ -107,15 +114,19 @@ Status MatMulInteger<uint8_t, int8_t>::Compute(OpKernelContext* ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
// NOTE: Eigen based implementation is a reference implementation for accuracy only
|
||||
for (int i = 0; i < static_cast<int>(helper.OutputOffsets().size()); i++) {
|
||||
EigenCastGEMM<uint8_t, int8_t, int32_t>(
|
||||
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
|
||||
b->template Data<int8_t>() + helper.RightOffsets()[i],
|
||||
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
|
||||
static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()));
|
||||
QGemmu8s8_s32(static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()),
|
||||
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
|
||||
static_cast<int>(helper.K()),
|
||||
0,
|
||||
b->template Data<int8_t>() + helper.RightOffsets()[i],
|
||||
static_cast<int>(helper.N()),
|
||||
0,
|
||||
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
|
||||
static_cast<int>(helper.N()),
|
||||
thread_pool);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,50 @@
|
|||
|
||||
#include "core/util/qmath.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#if defined(_M_AMD64) || defined(__x86_64__) || defined(_M_IX86) || defined(__i386__)
|
||||
#define MLAS_SUPPORTS_GEMM_U8X8
|
||||
#else
|
||||
// default to gemmlowp when building for arm devices
|
||||
#ifndef USE_GEMMLOWP
|
||||
#define USE_GEMMLOWP
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_GEMMLOWP
|
||||
#include "core/util/gemmlowp_common.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void QGemmu8s8_s32(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
const uint8_t* lhs_data,
|
||||
int lda,
|
||||
const uint8_t lhs_offset,
|
||||
const int8_t* rhs_data,
|
||||
int ldb,
|
||||
const int8_t rhs_offset,
|
||||
int32_t* result_data,
|
||||
int ldc,
|
||||
concurrency::ThreadPool* thread_pool) {
|
||||
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
||||
|
||||
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
|
||||
|
||||
#else
|
||||
ORT_ENFORCE(lhs_offset == 0 && rhs_offset == 0, "For Eigen, zero point must be zero");
|
||||
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For Eigen only RowMajor*RowMajor=RowMajor format is supported");
|
||||
|
||||
EigenCastGEMM<uint8_t, int8_t, int32_t>(lhs_data, rhs_data, result_data, M, N, K);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
void QGemmu8u8_s32(
|
||||
int M,
|
||||
int N,
|
||||
|
|
@ -21,13 +62,13 @@ void QGemmu8u8_s32(
|
|||
concurrency::ThreadPool* thread_pool) {
|
||||
#ifdef USE_GEMMLOWP
|
||||
|
||||
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported");
|
||||
ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported");
|
||||
|
||||
GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool);
|
||||
GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool);
|
||||
|
||||
#else
|
||||
MlasQgemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
|
||||
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
|
||||
|
||||
#endif
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,26 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
// default to gemmlowp when building for arm devices
|
||||
#ifndef USE_GEMMLOWP
|
||||
#if defined(_M_ARM64) || defined(__aarch64__)
|
||||
#define USE_GEMMLOWP
|
||||
#endif
|
||||
#if defined(_M_ARM) || defined(__arm__)
|
||||
#define USE_GEMMLOWP
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_GEMMLOWP
|
||||
#include "core/util/gemmlowp_common.h"
|
||||
#else
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#endif
|
||||
#include "core/platform/threadpool.h"
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Multiplies a uint8 left-hand matrix (lhs_data, leading dimension lda, zero
// point lhs_offset) by an int8 right-hand matrix (rhs_data, leading dimension
// ldb, zero point rhs_offset) and writes int32 results to result_data (leading
// dimension ldc). The definition forwards to MlasGemm on x86/x64 builds; on
// other targets it uses an Eigen path that requires both offsets to be zero
// and lda == K, ldb == N, ldc == N. thread_pool is passed through to MlasGemm.
void QGemmu8s8_s32(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
const uint8_t* lhs_data,
|
||||
int lda,
|
||||
const uint8_t lhs_offset,
|
||||
const int8_t* rhs_data,
|
||||
int ldb,
|
||||
const int8_t rhs_offset,
|
||||
int32_t* result_data,
|
||||
int ldc,
|
||||
concurrency::ThreadPool* thread_pool);
|
||||
|
||||
void QGemmu8u8_s32(
|
||||
int M,
|
||||
int N,
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ Abstract:
|
|||
#endif
|
||||
|
||||
#if defined(_M_IX86) || defined(__i386__) || defined(_M_AMD64) || defined(__x86_64__)
|
||||
#define MLAS_HAS_QGEMM_U8U8
|
||||
#define MLAS_HAS_QGEMM_U8X8
|
||||
#endif
|
||||
|
||||
MLAS_THREADPOOL* threadpool = nullptr;
|
||||
|
|
@ -452,9 +452,10 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
#ifdef MLAS_HAS_QGEMM_U8U8
|
||||
#ifdef MLAS_HAS_QGEMM_U8X8
|
||||
|
||||
class MlasQgemmU8U8Test : public MlasTestBase
|
||||
template <typename xint8_t>
|
||||
class MlasQgemmU8X8Test : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
void
|
||||
|
|
@ -467,11 +468,11 @@ private:
|
|||
)
|
||||
{
|
||||
const uint8_t* A = BufferA.GetBuffer(K * M);
|
||||
const uint8_t* B = BufferB.GetBuffer(N * K);
|
||||
const xint8_t* B = BufferB.GetBuffer(N * K);
|
||||
int32_t* C = BufferC.GetBuffer(N * M);
|
||||
int32_t* CReference = BufferCReference.GetBuffer(N * M);
|
||||
|
||||
Test(M, N, K, A, K, offa, B, N, offb, C, CReference, N);
|
||||
Test(M, N, K, A, K, offa, B, N, xint8_t(offb), C, CReference, N);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -482,9 +483,9 @@ private:
|
|||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
const xint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
xint8_t offb,
|
||||
int32_t* C,
|
||||
int32_t* CReference,
|
||||
size_t ldc
|
||||
|
|
@ -493,7 +494,7 @@ private:
|
|||
std::fill_n(C, M * N, -1);
|
||||
std::fill_n(CReference, M * N, -1);
|
||||
|
||||
MlasQgemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool);
|
||||
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool);
|
||||
ReferenceQgemm(M, N, K, A, lda, offa, B, ldb, offb, CReference, ldc);
|
||||
|
||||
for (size_t f = 0; f < M * N; f++) {
|
||||
|
|
@ -511,9 +512,9 @@ private:
|
|||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
const xint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
xint8_t offb,
|
||||
int32_t* C,
|
||||
size_t ldc
|
||||
)
|
||||
|
|
@ -523,7 +524,7 @@ private:
|
|||
for (size_t n = 0; n < N; n++) {
|
||||
|
||||
const uint8_t* a = A + (m * lda);
|
||||
const uint8_t* b = B + n;
|
||||
const xint8_t* b = B + n;
|
||||
int32_t* c = C + (m * ldc) + n;
|
||||
int32_t sum = 0;
|
||||
|
||||
|
|
@ -539,7 +540,7 @@ private:
|
|||
}
|
||||
|
||||
MatrixGuardBuffer<uint8_t> BufferA;
|
||||
MatrixGuardBuffer<uint8_t> BufferB;
|
||||
MatrixGuardBuffer<xint8_t> BufferB;
|
||||
MatrixGuardBuffer<int32_t> BufferC;
|
||||
MatrixGuardBuffer<int32_t> BufferCReference;
|
||||
|
||||
|
|
@ -565,7 +566,7 @@ public:
|
|||
void
|
||||
) override
|
||||
{
|
||||
static const uint8_t zero_points[] = { 0, 18, 128, 157, 231, 255 };
|
||||
static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 };
|
||||
|
||||
for (size_t a = 0; a < _countof(zero_points); a++) {
|
||||
uint8_t offa = zero_points[a];
|
||||
|
|
@ -1970,9 +1971,10 @@ main(
|
|||
printf("SGEMM tests.\n");
|
||||
std::make_unique<MlasSgemmTest>()->ExecuteShort();
|
||||
|
||||
#ifdef MLAS_HAS_QGEMM_U8U8
|
||||
#ifdef MLAS_HAS_QGEMM_U8X8
|
||||
printf("QGEMM tests.\n");
|
||||
std::make_unique<MlasQgemmU8U8Test>()->ExecuteShort();
|
||||
std::make_unique<MlasQgemmU8X8Test<int8_t>>()->ExecuteShort();
|
||||
std::make_unique<MlasQgemmU8X8Test<uint8_t>>()->ExecuteShort();
|
||||
#endif
|
||||
|
||||
printf("Conv2D tests.\n");
|
||||
|
|
|
|||
Loading…
Reference in a new issue