MLAS: add sgemv path for aarch64 builds (#4254)

Implement a fast path for GEMMs where M=1 and TransB=CblasNoTrans.
This commit is contained in:
Tracy Sharpe 2020-06-17 20:10:35 -07:00 committed by GitHub
parent 5da849b414
commit 5d773ee57b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 406 additions and 11 deletions

View file

@ -141,6 +141,7 @@ else()
enable_language(ASM)
set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemvKernelNeon.S
)
elseif(POWER)
set(mlas_platform_srcs

View file

@ -1,6 +1,6 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.\
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
@ -15,11 +15,9 @@ Abstract:
--*/
#include "asmmacro.h"
.text
.p2align 2
//
// ClearRowAccumulators
@ -420,7 +418,7 @@ Arguments:
ldc (x7) - Supplies the first dimension of matrix C.
Alpha (s0) - Supplies the scaler multiplier (see SGEMM definition).
Alpha (s0) - Supplies the scalar multiplier (see SGEMM definition).
Return Value:
@ -430,11 +428,7 @@ Return Value:
.macro SgemmKernelNeonFunction Mode
.globl C_UNDERSCORE(MlasSgemmKernel\Mode\())
#ifndef __APPLE__
.type C_UNDERSCORE(MlasSgemmKernel\Mode\()),%function
#endif
C_UNDERSCORE(MlasSgemmKernel\Mode\()):
FUNCTION_ENTRY MlasSgemmKernel\Mode\()
stp d8,d9,[sp,#-32]!
stp d10,d11,[sp,#16]

View file

@ -0,0 +1,303 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
SgemvKernelNeon.s
Abstract:
This module implements the kernels for the single precision matrix/vector
multiply operation (SGEMV).
--*/
#include "asmmacro.h"
.text
/*++
Routine Description:
This routine is an inner kernel to compute matrix multiplication for a
set of rows. This handles the special case of M=1.
The elements in matrix B are not transposed.
Arguments:
A (x0) - Supplies the address of matrix A.
B (x1) - Supplies the address of matrix B.
C (x2) - Supplies the address of matrix C.
CountK (x3) - Supplies the number of columns from matrix A and the number
of rows from matrix B to iterate over.
CountN (x4) - Supplies the number of columns from matrix B and matrix C to
iterate over.
ldb (x5) - Supplies the first dimension of matrix B.
ZeroMode (x6) - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
None.
--*/
FUNCTION_ENTRY MlasGemvFloatKernel
cmp x4,#64
blo .LSgemvN.ProcessRemainingCountN
mov x14,x0 // preserve vector A
//
// Process 64 columns at a time in a loop.
//
.LSgemvN.ProcessColumnLoopBy64:
ldr q4,[x1]
add x15,x1,#256 // compute next matrix B
ldr q5,[x1,#16]
tst w6,0xFF // ZeroMode?
mov x13,x3 // reload CountK
ldr q6,[x1,#32]
beq .LSgemvN.LoadOutputBy64
movi v16.4s,#0
movi v17.4s,#0
movi v18.4s,#0
movi v19.4s,#0
movi v20.4s,#0
movi v21.4s,#0
movi v22.4s,#0
movi v23.4s,#0
movi v24.4s,#0
movi v25.4s,#0
movi v26.4s,#0
movi v27.4s,#0
movi v28.4s,#0
movi v29.4s,#0
movi v30.4s,#0
movi v31.4s,#0
b .LSgemvN.MultiplyAccumulateBy64
.LSgemvN.LoadOutputBy64:
ldp q16,q17,[x2]
ldp q18,q19,[x2,#32]
ldp q20,q21,[x2,#64]
ldp q22,q23,[x2,#96]
ldp q24,q25,[x2,#128]
ldp q26,q27,[x2,#160]
ldp q28,q29,[x2,#192]
ldp q30,q31,[x2,#224]
.LSgemvN.MultiplyAccumulateBy64:
ld1r {v0.4s},[x0] // broadcast next vector A element
add x0,x0,4 // advance vector A by 1 element
sub x13,x13,#1 // decrement K remaining
fmla v16.4s,v4.4s,v0.4s
ldr q7,[x1,#48]
fmla v17.4s,v5.4s,v0.4s
ldr q4,[x1,#64]
fmla v18.4s,v6.4s,v0.4s
ldr q5,[x1,#80]
fmla v19.4s,v7.4s,v0.4s
ldr q6,[x1,#96]
fmla v20.4s,v4.4s,v0.4s
ldr q7,[x1,#112]
fmla v21.4s,v5.4s,v0.4s
ldr q4,[x1,#128]
fmla v22.4s,v6.4s,v0.4s
ldr q5,[x1,#144]
fmla v23.4s,v7.4s,v0.4s
ldr q6,[x1,#160]
fmla v24.4s,v4.4s,v0.4s
ldr q7,[x1,#176]
fmla v25.4s,v5.4s,v0.4s
ldr q4,[x1,#192]
fmla v26.4s,v6.4s,v0.4s
ldr q5,[x1,#208]
fmla v27.4s,v7.4s,v0.4s
ldr q6,[x1,#224]
fmla v28.4s,v4.4s,v0.4s
ldr q7,[x1,#240]
add x1,x1,x5,lsl #2 // compute next matrix B row address
cbz x13,.LSgemvN.StoreOutputBy64
ldr q4,[x1] // load data for the next iteration
fmla v29.4s,v5.4s,v0.4s
ldr q5,[x1,#16]
fmla v30.4s,v6.4s,v0.4s
ldr q6,[x1,#32]
fmla v31.4s,v7.4s,v0.4s
b .LSgemvN.MultiplyAccumulateBy64
.LSgemvN.StoreOutputBy64:
stp q16,q17,[x2]
fmla v29.4s,v5.4s,v0.4s // finish computing the tail vectors
stp q18,q19,[x2,#32]
fmla v30.4s,v6.4s,v0.4s
stp q20,q21,[x2,#64]
fmla v31.4s,v7.4s,v0.4s
stp q22,q23,[x2,#96]
sub x4,x4,#64 // subtract 64 columns
stp q24,q25,[x2,#128]
mov x0,x14 // reload vector A
stp q26,q27,[x2,#160]
mov x1,x15 // load next matrix B
stp q28,q29,[x2,#192]
stp q30,q31,[x2,#224]
add x2,x2,#256 // advance vector C by 64 columns
cbz x4,.LSgemvN.ExitKernel
cmp x4,#64
bhs .LSgemvN.ProcessColumnLoopBy64
//
// Process the remaining 1 to 63 columns.
//
.LSgemvN.ProcessRemainingCountN:
tst w6,0xFF // ZeroMode?
beq .LSgemvN.LoadOutputPartial32
movi v16.4s,#0
movi v17.4s,#0
movi v18.4s,#0
movi v19.4s,#0
movi v20.4s,#0
movi v21.4s,#0
movi v22.4s,#0
movi v23.4s,#0
movi v24.4s,#0
movi v25.4s,#0
movi v26.4s,#0
movi v27.4s,#0
movi v28.4s,#0
movi v29.4s,#0
movi v30.4s,#0
movi v31.4s,#0 // trailing float[2]
movi v1.4s,#0 // trailing float[1]
b .LSgemvN.ProcessNextPartialRow
.LSgemvN.LoadOutputPartial32:
mov x15,x2
tbz x4,#5,.LSgemvN.LoadOutputPartial16
ldp q16,q17,[x15],#128
ldp q18,q19,[x15,#-96]
ldp q20,q21,[x15,#-64]
ldp q22,q23,[x15,#-32]
.LSgemvN.LoadOutputPartial16:
tbz x4,#4,.LSgemvN.LoadOutputPartial8
ldp q24,q25,[x15],#64
ldp q26,q27,[x15,#-32]
.LSgemvN.LoadOutputPartial8:
tbz x4,#3,.LSgemvN.LoadOutputPartial4
ldp q28,q29,[x15],#32
.LSgemvN.LoadOutputPartial4:
tbz x4,#2,.LSgemvN.LoadOutputPartial2
ldr q30,[x15],#16
.LSgemvN.LoadOutputPartial2:
tbz x4,#1,.LSgemvN.LoadOutputPartial1
ldr d31,[x15],#8
.LSgemvN.LoadOutputPartial1:
tbz x4,#0,.LSgemvN.ProcessNextPartialRow
ldr s1,[x15]
.LSgemvN.ProcessNextPartialRow:
ld1r {v0.4s},[x0]
add x0,x0,4
sub x3,x3,#1 // decrement K remaining
mov x15,x1
.LSgemvN.MultiplyAccumulatePartial32:
tbz x4,#5,.LSgemvN.MultiplyAccumulatePartial16
ldp q4,q5,[x15],#128
fmla v16.4s,v4.4s,v0.4s
ldp q6,q7,[x15,#-96]
fmla v17.4s,v5.4s,v0.4s
ldp q4,q5,[x15,#-64]
fmla v18.4s,v6.4s,v0.4s
fmla v19.4s,v7.4s,v0.4s
ldp q6,q7,[x15,#-32]
fmla v20.4s,v4.4s,v0.4s
fmla v21.4s,v5.4s,v0.4s
fmla v22.4s,v6.4s,v0.4s
fmla v23.4s,v7.4s,v0.4s
.LSgemvN.MultiplyAccumulatePartial16:
tbz x4,#4,.LSgemvN.MultiplyAccumulatePartial8
ldp q4,q5,[x15],#64
fmla v24.4s,v4.4s,v0.4s
ldp q6,q7,[x15,#-32]
fmla v25.4s,v5.4s,v0.4s
fmla v26.4s,v6.4s,v0.4s
fmla v27.4s,v7.4s,v0.4s
.LSgemvN.MultiplyAccumulatePartial8:
tbz x4,#3,.LSgemvN.MultiplyAccumulatePartial4
ldp q4,q5,[x15],#32
fmla v28.4s,v4.4s,v0.4s
fmla v29.4s,v5.4s,v0.4s
.LSgemvN.MultiplyAccumulatePartial4:
tbz x4,#2,.LSgemvN.MultiplyAccumulatePartial2
ldr q4,[x15],#16
fmla v30.4s,v4.4s,v0.4s
.LSgemvN.MultiplyAccumulatePartial2:
tbz x4,#1,.LSgemvN.MultiplyAccumulatePartial1
ldr d4,[x15],#8
fmla v31.4s,v4.4s,v0.4s
.LSgemvN.MultiplyAccumulatePartial1:
tbz x4,#0,.LSgemvN.AdvancePartialRow
ldr s4,[x15]
fmla v1.4s,v4.4s,v0.4s
.LSgemvN.AdvancePartialRow:
add x1,x1,x5,lsl #2 // compute next matrix B row address
cbnz x3,.LSgemvN.ProcessNextPartialRow
.LSgemvN.StoreOutputPartial32:
tbz x4,#5,.LSgemvN.StoreOutputPartial16
stp q16,q17,[x2],#128
stp q18,q19,[x2,#-96]
stp q20,q21,[x2,#-64]
stp q22,q23,[x2,#-32]
.LSgemvN.StoreOutputPartial16:
tbz x4,#4,.LSgemvN.StoreOutputPartial8
stp q24,q25,[x2],#64
stp q26,q27,[x2,#-32]
.LSgemvN.StoreOutputPartial8:
tbz x4,#3,.LSgemvN.StoreOutputPartial4
stp q28,q29,[x2],#32
.LSgemvN.StoreOutputPartial4:
tbz x4,#2,.LSgemvN.StoreOutputPartial2
str q30,[x2],#16
.LSgemvN.StoreOutputPartial2:
tbz x4,#1,.LSgemvN.StoreOutputPartial1
str d31,[x2],#8
.LSgemvN.StoreOutputPartial1:
tbz x4,#0,.LSgemvN.ExitKernel
str s1,[x2]
.LSgemvN.ExitKernel:
ret
.end

View file

@ -14,8 +14,82 @@ Abstract:
--*/
/*++
Macro Description:
This macro emits the assembler directives to annotate a new function.
Arguments:
FunctionName - Supplies the name of the function.
--*/
.macro FUNCTION_ENTRY FunctionName
.p2align 2
#if defined(__APPLE__)
#define C_UNDERSCORE(symbol) _##symbol
.globl _\FunctionName\()
_\FunctionName\():
#else
#define C_UNDERSCORE(symbol) symbol
.globl \FunctionName\()
.type \FunctionName\(),%function
\FunctionName\():
#endif
.endm
/*++
Macro Description:
This macro conditionally emits the statement if Count is greater than or
equal to Value.
Arguments:
Count - Supplies the variable used in the comparison.
Value - Supplies the static used in the comparison.
Statement - Supplies the statement to conditionally emit.
--*/
.macro EmitIfCountGE Count1, Value1, Statement
.if (\Count1\() >= \Value1\())
\Statement\()
.endif
.endm
/*++
Macro Description:
This macro conditionally emits the statement if Count1 is greater than or
equal to Value1 and Count2 is greater than or equal to Value2.
Arguments:
Count1 - Supplies the variable used in the comparison.
Value1 - Supplies the static used in the comparison.
Count2 - Supplies the variable used in the comparison.
Value2 - Supplies the static used in the comparison.
Statement - Supplies the statement to conditionally emit.
--*/
.macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement
.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\())
\Statement\()
.endif
.endm

View file

@ -227,6 +227,20 @@ typedef MLAS_GEMM_FLOAT_KERNEL* PMLAS_GEMM_FLOAT_KERNEL;
typedef MLAS_GEMM_DOUBLE_KERNEL* PMLAS_GEMM_DOUBLE_KERNEL;
typedef
size_t
(MLASCALL MLAS_GEMV_FLOAT_KERNEL)(
const float* A,
const float* B,
float* C,
size_t CountK,
size_t CountN,
size_t ldb,
bool ZeroMode
);
typedef MLAS_GEMV_FLOAT_KERNEL* PMLAS_GEMV_FLOAT_KERNEL;
typedef
void
(MLASCALL MLAS_SGEMM_KERNEL_M1_ROUTINE)(
@ -473,6 +487,8 @@ extern "C" {
#if defined(MLAS_TARGET_AMD64)
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1Avx;
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1TransposeBAvx;
#elif defined(MLAS_TARGET_ARM64)
MLAS_GEMV_FLOAT_KERNEL MlasGemvFloatKernel;
#endif
#if defined(MLAS_TARGET_AMD64)

View file

@ -915,6 +915,13 @@ Return Value:
return;
}
#elif defined(MLAS_TARGET_ARM64) && !defined(_WIN32)
if (TransB == CblasNoTrans) {
MlasGemvFloatKernel(A, B, C, K, N, ldb, (beta == 0.0f));
return;
}
#endif
}