mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
MLAS: add sgemv path for aarch64 builds (#4254)
Implement a fast path for GEMMs where M=1 and TransB=CblasNoTrans.
This commit is contained in:
parent
5da849b414
commit
5d773ee57b
6 changed files with 406 additions and 11 deletions
|
|
@ -141,6 +141,7 @@ else()
|
|||
enable_language(ASM)
|
||||
set(mlas_platform_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemvKernelNeon.S
|
||||
)
|
||||
elseif(POWER)
|
||||
set(mlas_platform_srcs
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.\
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
|
|
@ -15,11 +15,9 @@ Abstract:
|
|||
|
||||
--*/
|
||||
|
||||
|
||||
#include "asmmacro.h"
|
||||
|
||||
.text
|
||||
.p2align 2
|
||||
|
||||
//
|
||||
// ClearRowAccumulators
|
||||
|
|
@ -420,7 +418,7 @@ Arguments:
|
|||
|
||||
ldc (x7) - Supplies the first dimension of matrix C.
|
||||
|
||||
Alpha (s0) - Supplies the scaler multiplier (see SGEMM definition).
|
||||
Alpha (s0) - Supplies the scalar multiplier (see SGEMM definition).
|
||||
|
||||
Return Value:
|
||||
|
||||
|
|
@ -430,11 +428,7 @@ Return Value:
|
|||
|
||||
.macro SgemmKernelNeonFunction Mode
|
||||
|
||||
.globl C_UNDERSCORE(MlasSgemmKernel\Mode\())
|
||||
#ifndef __APPLE__
|
||||
.type C_UNDERSCORE(MlasSgemmKernel\Mode\()),%function
|
||||
#endif
|
||||
C_UNDERSCORE(MlasSgemmKernel\Mode\()):
|
||||
FUNCTION_ENTRY MlasSgemmKernel\Mode\()
|
||||
|
||||
stp d8,d9,[sp,#-32]!
|
||||
stp d10,d11,[sp,#16]
|
||||
|
|
|
|||
303
onnxruntime/core/mlas/lib/aarch64/SgemvKernelNeon.S
Normal file
303
onnxruntime/core/mlas/lib/aarch64/SgemvKernelNeon.S
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
SgemvKernelNeon.s
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the single precision matrix/vector
|
||||
multiply operation (SGEMV).
|
||||
|
||||
--*/
|
||||
|
||||
#include "asmmacro.h"
|
||||
|
||||
.text
|
||||
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows. This handles the special case of M=1.
|
||||
|
||||
The elements in matrix B are not transposed.
|
||||
|
||||
Arguments:
|
||||
|
||||
A (x0) - Supplies the address of matrix A.
|
||||
|
||||
B (x1) - Supplies the address of matrix B.
|
||||
|
||||
C (x2) - Supplies the address of matrix C.
|
||||
|
||||
CountK (x3) - Supplies the number of columns from matrix A and the number
|
||||
of rows from matrix B to iterate over.
|
||||
|
||||
CountN (x4) - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
ldb (x5) - Supplies the first dimension of matrix B.
|
||||
|
||||
ZeroMode (x6) - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
|
||||
FUNCTION_ENTRY MlasGemvFloatKernel
|
||||
|
||||
cmp x4,#64
|
||||
blo .LSgemvN.ProcessRemainingCountN
|
||||
mov x14,x0 // preserve vector A
|
||||
|
||||
//
|
||||
// Process 64 columns at a time in a loop.
|
||||
//
|
||||
|
||||
.LSgemvN.ProcessColumnLoopBy64:
|
||||
ldr q4,[x1]
|
||||
add x15,x1,#256 // compute next matrix B
|
||||
ldr q5,[x1,#16]
|
||||
tst w6,0xFF // ZeroMode?
|
||||
mov x13,x3 // reload CountK
|
||||
ldr q6,[x1,#32]
|
||||
beq .LSgemvN.LoadOutputBy64
|
||||
movi v16.4s,#0
|
||||
movi v17.4s,#0
|
||||
movi v18.4s,#0
|
||||
movi v19.4s,#0
|
||||
movi v20.4s,#0
|
||||
movi v21.4s,#0
|
||||
movi v22.4s,#0
|
||||
movi v23.4s,#0
|
||||
movi v24.4s,#0
|
||||
movi v25.4s,#0
|
||||
movi v26.4s,#0
|
||||
movi v27.4s,#0
|
||||
movi v28.4s,#0
|
||||
movi v29.4s,#0
|
||||
movi v30.4s,#0
|
||||
movi v31.4s,#0
|
||||
b .LSgemvN.MultiplyAccumulateBy64
|
||||
|
||||
.LSgemvN.LoadOutputBy64:
|
||||
ldp q16,q17,[x2]
|
||||
ldp q18,q19,[x2,#32]
|
||||
ldp q20,q21,[x2,#64]
|
||||
ldp q22,q23,[x2,#96]
|
||||
ldp q24,q25,[x2,#128]
|
||||
ldp q26,q27,[x2,#160]
|
||||
ldp q28,q29,[x2,#192]
|
||||
ldp q30,q31,[x2,#224]
|
||||
|
||||
.LSgemvN.MultiplyAccumulateBy64:
|
||||
ld1r {v0.4s},[x0] // broadcast next vector A element
|
||||
add x0,x0,4 // advance vector A by 1 element
|
||||
sub x13,x13,#1 // decrement K remaining
|
||||
fmla v16.4s,v4.4s,v0.4s
|
||||
ldr q7,[x1,#48]
|
||||
fmla v17.4s,v5.4s,v0.4s
|
||||
ldr q4,[x1,#64]
|
||||
fmla v18.4s,v6.4s,v0.4s
|
||||
ldr q5,[x1,#80]
|
||||
fmla v19.4s,v7.4s,v0.4s
|
||||
ldr q6,[x1,#96]
|
||||
fmla v20.4s,v4.4s,v0.4s
|
||||
ldr q7,[x1,#112]
|
||||
fmla v21.4s,v5.4s,v0.4s
|
||||
ldr q4,[x1,#128]
|
||||
fmla v22.4s,v6.4s,v0.4s
|
||||
ldr q5,[x1,#144]
|
||||
fmla v23.4s,v7.4s,v0.4s
|
||||
ldr q6,[x1,#160]
|
||||
fmla v24.4s,v4.4s,v0.4s
|
||||
ldr q7,[x1,#176]
|
||||
fmla v25.4s,v5.4s,v0.4s
|
||||
ldr q4,[x1,#192]
|
||||
fmla v26.4s,v6.4s,v0.4s
|
||||
ldr q5,[x1,#208]
|
||||
fmla v27.4s,v7.4s,v0.4s
|
||||
ldr q6,[x1,#224]
|
||||
fmla v28.4s,v4.4s,v0.4s
|
||||
ldr q7,[x1,#240]
|
||||
add x1,x1,x5,lsl #2 // compute next matrix B row address
|
||||
cbz x13,.LSgemvN.StoreOutputBy64
|
||||
ldr q4,[x1] // load data for the next iteration
|
||||
fmla v29.4s,v5.4s,v0.4s
|
||||
ldr q5,[x1,#16]
|
||||
fmla v30.4s,v6.4s,v0.4s
|
||||
ldr q6,[x1,#32]
|
||||
fmla v31.4s,v7.4s,v0.4s
|
||||
b .LSgemvN.MultiplyAccumulateBy64
|
||||
|
||||
.LSgemvN.StoreOutputBy64:
|
||||
stp q16,q17,[x2]
|
||||
fmla v29.4s,v5.4s,v0.4s // finish computing the tail vectors
|
||||
stp q18,q19,[x2,#32]
|
||||
fmla v30.4s,v6.4s,v0.4s
|
||||
stp q20,q21,[x2,#64]
|
||||
fmla v31.4s,v7.4s,v0.4s
|
||||
stp q22,q23,[x2,#96]
|
||||
sub x4,x4,#64 // subtract 64 columns
|
||||
stp q24,q25,[x2,#128]
|
||||
mov x0,x14 // reload vector A
|
||||
stp q26,q27,[x2,#160]
|
||||
mov x1,x15 // load next matrix B
|
||||
stp q28,q29,[x2,#192]
|
||||
stp q30,q31,[x2,#224]
|
||||
add x2,x2,#256 // advance vector C by 64 columns
|
||||
cbz x4,.LSgemvN.ExitKernel
|
||||
cmp x4,#64
|
||||
bhs .LSgemvN.ProcessColumnLoopBy64
|
||||
|
||||
//
|
||||
// Process the remaining 1 to 63 columns.
|
||||
//
|
||||
|
||||
.LSgemvN.ProcessRemainingCountN:
|
||||
tst w6,0xFF // ZeroMode?
|
||||
beq .LSgemvN.LoadOutputPartial32
|
||||
movi v16.4s,#0
|
||||
movi v17.4s,#0
|
||||
movi v18.4s,#0
|
||||
movi v19.4s,#0
|
||||
movi v20.4s,#0
|
||||
movi v21.4s,#0
|
||||
movi v22.4s,#0
|
||||
movi v23.4s,#0
|
||||
movi v24.4s,#0
|
||||
movi v25.4s,#0
|
||||
movi v26.4s,#0
|
||||
movi v27.4s,#0
|
||||
movi v28.4s,#0
|
||||
movi v29.4s,#0
|
||||
movi v30.4s,#0
|
||||
movi v31.4s,#0 // trailing float[2]
|
||||
movi v1.4s,#0 // trailing float[1]
|
||||
b .LSgemvN.ProcessNextPartialRow
|
||||
|
||||
.LSgemvN.LoadOutputPartial32:
|
||||
mov x15,x2
|
||||
tbz x4,#5,.LSgemvN.LoadOutputPartial16
|
||||
ldp q16,q17,[x15],#128
|
||||
ldp q18,q19,[x15,#-96]
|
||||
ldp q20,q21,[x15,#-64]
|
||||
ldp q22,q23,[x15,#-32]
|
||||
|
||||
.LSgemvN.LoadOutputPartial16:
|
||||
tbz x4,#4,.LSgemvN.LoadOutputPartial8
|
||||
ldp q24,q25,[x15],#64
|
||||
ldp q26,q27,[x15,#-32]
|
||||
|
||||
.LSgemvN.LoadOutputPartial8:
|
||||
tbz x4,#3,.LSgemvN.LoadOutputPartial4
|
||||
ldp q28,q29,[x15],#32
|
||||
|
||||
.LSgemvN.LoadOutputPartial4:
|
||||
tbz x4,#2,.LSgemvN.LoadOutputPartial2
|
||||
ldr q30,[x15],#16
|
||||
|
||||
.LSgemvN.LoadOutputPartial2:
|
||||
tbz x4,#1,.LSgemvN.LoadOutputPartial1
|
||||
ldr d31,[x15],#8
|
||||
|
||||
.LSgemvN.LoadOutputPartial1:
|
||||
tbz x4,#0,.LSgemvN.ProcessNextPartialRow
|
||||
ldr s1,[x15]
|
||||
|
||||
.LSgemvN.ProcessNextPartialRow:
|
||||
ld1r {v0.4s},[x0]
|
||||
add x0,x0,4
|
||||
sub x3,x3,#1 // decrement K remaining
|
||||
mov x15,x1
|
||||
|
||||
.LSgemvN.MultiplyAccumulatePartial32:
|
||||
tbz x4,#5,.LSgemvN.MultiplyAccumulatePartial16
|
||||
ldp q4,q5,[x15],#128
|
||||
fmla v16.4s,v4.4s,v0.4s
|
||||
ldp q6,q7,[x15,#-96]
|
||||
fmla v17.4s,v5.4s,v0.4s
|
||||
ldp q4,q5,[x15,#-64]
|
||||
fmla v18.4s,v6.4s,v0.4s
|
||||
fmla v19.4s,v7.4s,v0.4s
|
||||
ldp q6,q7,[x15,#-32]
|
||||
fmla v20.4s,v4.4s,v0.4s
|
||||
fmla v21.4s,v5.4s,v0.4s
|
||||
fmla v22.4s,v6.4s,v0.4s
|
||||
fmla v23.4s,v7.4s,v0.4s
|
||||
|
||||
.LSgemvN.MultiplyAccumulatePartial16:
|
||||
tbz x4,#4,.LSgemvN.MultiplyAccumulatePartial8
|
||||
ldp q4,q5,[x15],#64
|
||||
fmla v24.4s,v4.4s,v0.4s
|
||||
ldp q6,q7,[x15,#-32]
|
||||
fmla v25.4s,v5.4s,v0.4s
|
||||
fmla v26.4s,v6.4s,v0.4s
|
||||
fmla v27.4s,v7.4s,v0.4s
|
||||
|
||||
.LSgemvN.MultiplyAccumulatePartial8:
|
||||
tbz x4,#3,.LSgemvN.MultiplyAccumulatePartial4
|
||||
ldp q4,q5,[x15],#32
|
||||
fmla v28.4s,v4.4s,v0.4s
|
||||
fmla v29.4s,v5.4s,v0.4s
|
||||
|
||||
.LSgemvN.MultiplyAccumulatePartial4:
|
||||
tbz x4,#2,.LSgemvN.MultiplyAccumulatePartial2
|
||||
ldr q4,[x15],#16
|
||||
fmla v30.4s,v4.4s,v0.4s
|
||||
|
||||
.LSgemvN.MultiplyAccumulatePartial2:
|
||||
tbz x4,#1,.LSgemvN.MultiplyAccumulatePartial1
|
||||
ldr d4,[x15],#8
|
||||
fmla v31.4s,v4.4s,v0.4s
|
||||
|
||||
.LSgemvN.MultiplyAccumulatePartial1:
|
||||
tbz x4,#0,.LSgemvN.AdvancePartialRow
|
||||
ldr s4,[x15]
|
||||
fmla v1.4s,v4.4s,v0.4s
|
||||
|
||||
.LSgemvN.AdvancePartialRow:
|
||||
add x1,x1,x5,lsl #2 // compute next matrix B row address
|
||||
cbnz x3,.LSgemvN.ProcessNextPartialRow
|
||||
|
||||
.LSgemvN.StoreOutputPartial32:
|
||||
tbz x4,#5,.LSgemvN.StoreOutputPartial16
|
||||
stp q16,q17,[x2],#128
|
||||
stp q18,q19,[x2,#-96]
|
||||
stp q20,q21,[x2,#-64]
|
||||
stp q22,q23,[x2,#-32]
|
||||
|
||||
.LSgemvN.StoreOutputPartial16:
|
||||
tbz x4,#4,.LSgemvN.StoreOutputPartial8
|
||||
stp q24,q25,[x2],#64
|
||||
stp q26,q27,[x2,#-32]
|
||||
|
||||
.LSgemvN.StoreOutputPartial8:
|
||||
tbz x4,#3,.LSgemvN.StoreOutputPartial4
|
||||
stp q28,q29,[x2],#32
|
||||
|
||||
.LSgemvN.StoreOutputPartial4:
|
||||
tbz x4,#2,.LSgemvN.StoreOutputPartial2
|
||||
str q30,[x2],#16
|
||||
|
||||
.LSgemvN.StoreOutputPartial2:
|
||||
tbz x4,#1,.LSgemvN.StoreOutputPartial1
|
||||
str d31,[x2],#8
|
||||
|
||||
.LSgemvN.StoreOutputPartial1:
|
||||
tbz x4,#0,.LSgemvN.ExitKernel
|
||||
str s1,[x2]
|
||||
|
||||
.LSgemvN.ExitKernel:
|
||||
ret
|
||||
|
||||
.end
|
||||
|
|
@ -14,8 +14,82 @@ Abstract:
|
|||
|
||||
--*/
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro emits the assembler directives to annotate a new function.
|
||||
|
||||
Arguments:
|
||||
|
||||
FunctionName - Supplies the name of the function.
|
||||
|
||||
--*/
|
||||
|
||||
.macro FUNCTION_ENTRY FunctionName
|
||||
|
||||
.p2align 2
|
||||
#if defined(__APPLE__)
|
||||
#define C_UNDERSCORE(symbol) _##symbol
|
||||
.globl _\FunctionName\()
|
||||
_\FunctionName\():
|
||||
#else
|
||||
#define C_UNDERSCORE(symbol) symbol
|
||||
.globl \FunctionName\()
|
||||
.type \FunctionName\(),%function
|
||||
\FunctionName\():
|
||||
#endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro conditionally emits the statement if Count is greater than or
|
||||
equal to Value.
|
||||
|
||||
Arguments:
|
||||
|
||||
Count - Supplies the variable used in the comparison.
|
||||
|
||||
Value - Supplies the static used in the comparison.
|
||||
|
||||
Statement - Supplies the statement to conditionally emit.
|
||||
|
||||
--*/
|
||||
|
||||
.macro EmitIfCountGE Count1, Value1, Statement
|
||||
|
||||
.if (\Count1\() >= \Value1\())
|
||||
\Statement\()
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
||||
Macro Description:
|
||||
|
||||
This macro conditionally emits the statement if Count1 is greater than or
|
||||
equal to Value1 and Count2 is greater than or equal to Value2.
|
||||
|
||||
Arguments:
|
||||
|
||||
Count1 - Supplies the variable used in the comparison.
|
||||
|
||||
Value1 - Supplies the static used in the comparison.
|
||||
|
||||
Count2 - Supplies the variable used in the comparison.
|
||||
|
||||
Value2 - Supplies the static used in the comparison.
|
||||
|
||||
Statement - Supplies the statement to conditionally emit.
|
||||
|
||||
--*/
|
||||
|
||||
.macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement
|
||||
|
||||
.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\())
|
||||
\Statement\()
|
||||
.endif
|
||||
|
||||
.endm
|
||||
|
|
|
|||
|
|
@ -227,6 +227,20 @@ typedef MLAS_GEMM_FLOAT_KERNEL* PMLAS_GEMM_FLOAT_KERNEL;
|
|||
|
||||
typedef MLAS_GEMM_DOUBLE_KERNEL* PMLAS_GEMM_DOUBLE_KERNEL;
|
||||
|
||||
typedef
|
||||
size_t
|
||||
(MLASCALL MLAS_GEMV_FLOAT_KERNEL)(
|
||||
const float* A,
|
||||
const float* B,
|
||||
float* C,
|
||||
size_t CountK,
|
||||
size_t CountN,
|
||||
size_t ldb,
|
||||
bool ZeroMode
|
||||
);
|
||||
|
||||
typedef MLAS_GEMV_FLOAT_KERNEL* PMLAS_GEMV_FLOAT_KERNEL;
|
||||
|
||||
typedef
|
||||
void
|
||||
(MLASCALL MLAS_SGEMM_KERNEL_M1_ROUTINE)(
|
||||
|
|
@ -473,6 +487,8 @@ extern "C" {
|
|||
#if defined(MLAS_TARGET_AMD64)
|
||||
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1Avx;
|
||||
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1TransposeBAvx;
|
||||
#elif defined(MLAS_TARGET_ARM64)
|
||||
MLAS_GEMV_FLOAT_KERNEL MlasGemvFloatKernel;
|
||||
#endif
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
|
|
|
|||
|
|
@ -915,6 +915,13 @@ Return Value:
|
|||
return;
|
||||
}
|
||||
|
||||
#elif defined(MLAS_TARGET_ARM64) && !defined(_WIN32)
|
||||
|
||||
if (TransB == CblasNoTrans) {
|
||||
MlasGemvFloatKernel(A, B, C, K, N, ldb, (beta == 0.0f));
|
||||
return;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue