mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-10 00:38:54 +00:00
POWER10: Add optimized dgemm kernel (#9652)
* POWER10: Add optimized dgemm kernel This patch makes use of POWER10 matrix multiply assist feature and adds new DGEMM kernel. * Indentation update Co-authored-by: Rajalakshmi Srinivasaraghavan <rajis@linux.ibm.com>
This commit is contained in:
parent
bf5e9a5044
commit
8564fc1933
8 changed files with 799 additions and 599 deletions
|
|
@ -291,6 +291,7 @@ else()
|
|||
${MLAS_SRC_DIR}/dgemm.cpp
|
||||
${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE")
|
||||
check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10)
|
||||
if(HAS_POWER10)
|
||||
set(CMAKE_REQUIRED_FLAGS "-mcpu=power10")
|
||||
|
|
@ -318,8 +319,10 @@ else()
|
|||
endif()
|
||||
set(mlas_platform_srcs_power10
|
||||
${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp
|
||||
${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp
|
||||
)
|
||||
set_source_files_properties(${mlas_platform_srcs_power10} PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE")
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
|
||||
set(mlas_platform_srcs
|
||||
${mlas_platform_srcs}
|
||||
${mlas_platform_srcs_power10}
|
||||
|
|
|
|||
|
|
@ -499,6 +499,7 @@ extern "C" {
|
|||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel;
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelPOWER10;
|
||||
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernel;
|
||||
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10;
|
||||
#else
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero;
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd;
|
||||
|
|
@ -1886,7 +1887,7 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
|
|||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_store_pd(Buffer, Vector);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
vec_st(Vector, 0, Buffer);
|
||||
*((MLAS_FLOAT64X2*)Buffer) = Vector;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -379,6 +379,7 @@ Return Value:
|
|||
bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1));
|
||||
if (HasP10Instructions) {
|
||||
this->GemmFloatKernel = MlasSgemmKernelPOWER10;
|
||||
this->GemmDoubleKernel = MlasDgemmKernelPOWER10;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
|||
418
onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp
Normal file
418
onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
DgemmKernelPOWER10.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the double precision matrix/matrix
|
||||
multiply operation (DGEMM).
|
||||
|
||||
--*/
|
||||
|
||||
#include "DgemmKernelpower.h"
|
||||
struct MlasDgemmBroadcastAElementsMMA
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
double ARow[RowCount],
|
||||
const double* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
ARow[Row] = A [Row * lda];
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasDgemmComputeAElements(
|
||||
MLAS_FLOAT64X2 AElements[RowCount],
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount]
|
||||
)
|
||||
{
|
||||
ABroadcast[0] = vec_mergee (AElements[0], AElements[1]);
|
||||
ABroadcast[1] = vec_mergee (AElements[2], AElements[3]);
|
||||
ABroadcast[2] = vec_mergeo (AElements[0], AElements[1]);
|
||||
ABroadcast[3] = vec_mergeo (AElements[2], AElements[3]);
|
||||
}
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasDgemmComputeBlockMMA(
|
||||
__vector_quad acc[8],
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount],
|
||||
MLAS_FLOAT64X2 A2Broadcast[RowCount],
|
||||
const double* B,
|
||||
size_t CountM
|
||||
)
|
||||
{
|
||||
MLAS_FLOAT64X2 BElements[4];
|
||||
typedef __vector unsigned char vec_t;
|
||||
__vector_pair A2pair, Apair;
|
||||
#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 2))
|
||||
__builtin_mma_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]);
|
||||
}
|
||||
#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2))
|
||||
Apair = *((__vector_pair *)((void *)&ABroadcast[0]));
|
||||
if (CountM == 8) {
|
||||
A2pair = *((__vector_pair *)((void *)&A2Broadcast[0]));
|
||||
}
|
||||
#else
|
||||
__builtin_vsx_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]);
|
||||
if (CountM == 8) {
|
||||
__builtin_vsx_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]);
|
||||
}
|
||||
#endif
|
||||
BElements[0] = MlasLoadFloat64x2(B);
|
||||
BElements[1] = MlasLoadFloat64x2(B + 2);
|
||||
BElements[2] = MlasLoadFloat64x2(B + 4);
|
||||
BElements[3] = MlasLoadFloat64x2(B + 6);
|
||||
__builtin_mma_xvf64gerpp (&acc[0], Apair, (vec_t)BElements[0]);
|
||||
__builtin_mma_xvf64gerpp (&acc[1], Apair, (vec_t)BElements[1]);
|
||||
__builtin_mma_xvf64gerpp (&acc[2], Apair, (vec_t)BElements[2]);
|
||||
__builtin_mma_xvf64gerpp (&acc[3], Apair, (vec_t)BElements[3]);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_xvf64gerpp (&acc[4], A2pair, (vec_t)BElements[0]);
|
||||
__builtin_mma_xvf64gerpp (&acc[5], A2pair, (vec_t)BElements[1]);
|
||||
__builtin_mma_xvf64gerpp (&acc[6], A2pair, (vec_t)BElements[2]);
|
||||
__builtin_mma_xvf64gerpp (&acc[7], A2pair, (vec_t)BElements[3]);
|
||||
}
|
||||
}
|
||||
template<size_t VectorCount>
|
||||
struct MlasDgemmStoreVectorMMA
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Result[4],
|
||||
double* C,
|
||||
size_t ldc,
|
||||
MLAS_FLOAT64X2 AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
MLAS_FLOAT64X2 *rowC;
|
||||
if (ZeroMode) {
|
||||
rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount];
|
||||
rowC[0] = Result[Row] * AlphaBroadcast;
|
||||
} else {
|
||||
rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount];
|
||||
rowC[0] += Result[Row] * AlphaBroadcast;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmMultiplyAlphaTrailingMMA
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount],
|
||||
MLAS_FLOAT64X2 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Row] = MlasMultiplyFloat64x2(Accumulators[Row], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
template<unsigned Lane>
|
||||
struct MlasDgemmStoreScalarMMA
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount],
|
||||
double* C,
|
||||
size_t ldc,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
double* c = C + Row * ldc + Lane;
|
||||
double Value = Accumulators[Row][Lane];
|
||||
if (!ZeroMode) {
|
||||
Value += *c;
|
||||
}
|
||||
|
||||
*c = Value;
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
size_t
|
||||
MlasDgemmMMAProcessCount(
|
||||
const double* A,
|
||||
const double* B,
|
||||
double* C,
|
||||
size_t CountM,
|
||||
size_t CountK,
|
||||
size_t CountN,
|
||||
size_t lda,
|
||||
size_t ldc,
|
||||
MLAS_FLOAT64X2 AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
do {
|
||||
|
||||
const double* a = A;
|
||||
size_t k = CountK;
|
||||
|
||||
MLAS_FLOAT64X2 Accumulators[2][RowCount] = {{ 0 }};
|
||||
MLAS_FLOAT64X2 Result[RowCount];
|
||||
MLAS_FLOAT64X2 AElements[RowCount];
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount] = { 0 };
|
||||
MLAS_FLOAT64X2 A2Broadcast[RowCount] = { 0 };
|
||||
MLAS_FLOAT64X2 A3Broadcast[RowCount] = { 0 };
|
||||
MLAS_FLOAT64X2 A4Broadcast[RowCount] = { 0 };
|
||||
double ARow[RowCount] = { 0 };
|
||||
double A2Row[RowCount] = { 0 };
|
||||
__vector_quad acc[8];
|
||||
|
||||
//
|
||||
// Clear the block accumulators.
|
||||
//
|
||||
__builtin_mma_xxsetaccz(&acc[0]);
|
||||
__builtin_mma_xxsetaccz(&acc[1]);
|
||||
__builtin_mma_xxsetaccz(&acc[2]);
|
||||
__builtin_mma_xxsetaccz(&acc[3]);
|
||||
__builtin_mma_xxsetaccz(&acc[4]);
|
||||
__builtin_mma_xxsetaccz(&acc[5]);
|
||||
__builtin_mma_xxsetaccz(&acc[6]);
|
||||
__builtin_mma_xxsetaccz(&acc[7]);
|
||||
|
||||
//
|
||||
// Compute the output block.
|
||||
//
|
||||
while (k >= 4) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
|
||||
MlasDgemmComputeAElements<RowCount>(AElements, ABroadcast);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a+2, lda);
|
||||
MlasDgemmComputeAElements<RowCount>(AElements, A3Broadcast);
|
||||
if (CountM == 8) {
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a + ( lda * 4), lda);
|
||||
MlasDgemmComputeAElements<RowCount>(AElements, A2Broadcast);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, (a+2) + ( lda * 4), lda);
|
||||
MlasDgemmComputeAElements<RowCount>(AElements, A4Broadcast);
|
||||
}
|
||||
MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &ABroadcast[0], &A2Broadcast[0], B, CountM);
|
||||
MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &ABroadcast[2], &A2Broadcast[2], B+8, CountM);
|
||||
MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &A3Broadcast[0], &A4Broadcast[0], B+16, CountM);
|
||||
MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &A3Broadcast[2], &A4Broadcast[2], B+24, CountM);
|
||||
B += 8 * 4;
|
||||
a += 4;
|
||||
k -= 4;
|
||||
}
|
||||
while (k > 0) {
|
||||
MlasLoopUnroll<RowCount, MlasDgemmBroadcastAElementsMMA>()(ARow, a, lda);
|
||||
if (CountM == 8) {
|
||||
MlasLoopUnroll<RowCount, MlasDgemmBroadcastAElementsMMA>()(A2Row, a + (lda * 4), lda);
|
||||
}
|
||||
|
||||
MlasDgemmComputeBlockMMA<RowCount>(&acc[0], (MLAS_FLOAT64X2 *)ARow, (MLAS_FLOAT64X2 *)A2Row, B, CountM);
|
||||
a += 1;
|
||||
B += 8;
|
||||
k -= 1;
|
||||
}
|
||||
if (CountN >= 8) {
|
||||
|
||||
//
|
||||
// Store the entire output block.
|
||||
//
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[3]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<6>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[7]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<6>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
} else {
|
||||
|
||||
//
|
||||
// Store the partial output block.
|
||||
//
|
||||
|
||||
if (CountN >= 6) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 6 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]);
|
||||
}
|
||||
}
|
||||
if (CountN - 6 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]);
|
||||
}
|
||||
} else if (CountN >= 4) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 4 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]);
|
||||
}
|
||||
}
|
||||
if (CountN - 4 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]);
|
||||
}
|
||||
} else if (CountN >= 2) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 2 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]);
|
||||
}
|
||||
}
|
||||
if (CountN - 2 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]);
|
||||
}
|
||||
} else {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Store the remaining unaligned columns.
|
||||
//
|
||||
C += (CountN & ~1);
|
||||
CountN &= 1;
|
||||
|
||||
if (CountN > 0) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmMultiplyAlphaTrailingMMA>()(Accumulators[0], AlphaBroadcast);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreScalarMMA<0>>()(Accumulators[0], C, ldc, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
MlasLoopUnroll<RowCount, MlasDgemmMultiplyAlphaTrailingMMA>()(Accumulators[1], AlphaBroadcast);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreScalarMMA<0>>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode);
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
C += 8;
|
||||
CountN -= 8;
|
||||
|
||||
} while (CountN > 0);
|
||||
|
||||
return CountM;
|
||||
}
|
||||
|
||||
size_t
|
||||
MLASCALL
|
||||
MlasDgemmKernelPOWER10(
|
||||
const double* A,
|
||||
const double* B,
|
||||
double* C,
|
||||
size_t CountK,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t lda,
|
||||
size_t ldc,
|
||||
double alpha,
|
||||
bool ZeroMode
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B. The matrix data has been packed using
|
||||
MlasDgemmCopyPackB or MlasDgemmTransposePackB.
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
CountK - Supplies the number of columns from matrix A and the number of rows
|
||||
from matrix B to iterate over.
|
||||
|
||||
CountM - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
alpha - Supplies the scalar multiplier (see DGEMM definition).
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
{
|
||||
size_t RowsHandled;
|
||||
MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha);
|
||||
if (CountM >= 8) {
|
||||
RowsHandled = MlasDgemmMMAProcessCount<4>(A, B, C, 8 ,CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountM >= 4) {
|
||||
RowsHandled = MlasDgemmMMAProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountM >= 2) {
|
||||
RowsHandled = MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else {
|
||||
RowsHandled = MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
|
||||
return RowsHandled;
|
||||
}
|
||||
|
|
@ -6,293 +6,16 @@ Licensed under the MIT License.
|
|||
|
||||
Module Name:
|
||||
|
||||
DgemmKernelPower.cpp
|
||||
DgemmKernelpower.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the single precision matrix/matrix
|
||||
This module implements the kernels for the double precision matrix/matrix
|
||||
multiply operation (DGEMM).
|
||||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
//
|
||||
// Templates to ensure that a loop is unrolled.
|
||||
//
|
||||
|
||||
template<size_t Count, size_t Index>
|
||||
struct MlasLoopUnrollStep
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
IterationType::template Iteration<Count, Index>(Arguments...);
|
||||
MlasLoopUnrollStep<Count, Index + 1>::template Step<IterationType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t Count>
|
||||
struct MlasLoopUnrollStep<Count, Count>
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&...
|
||||
)
|
||||
{
|
||||
// Terminate the loop.
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t Count, typename IteratorType>
|
||||
struct MlasLoopUnroll
|
||||
{
|
||||
template<typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
operator()(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
MlasLoopUnrollStep<Count, 0>::template Step<IteratorType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Templates used with loop unrolling to perform an action on one row of the
|
||||
// output.
|
||||
//
|
||||
|
||||
struct MlasDgemmZeroAccumulators
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount][4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasZeroFloat64x2();
|
||||
Accumulators[Row][1] = MlasZeroFloat64x2();
|
||||
Accumulators[Row][2] = MlasZeroFloat64x2();
|
||||
Accumulators[Row][3] = MlasZeroFloat64x2();
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmLoadAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 AElements[RowCount],
|
||||
const double* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
AElements[Row] = MlasLoadFloat64x2(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmBroadcastAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount],
|
||||
const double* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = MlasBroadcastFloat64x2(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned Lane>
|
||||
struct MlasDgemmSplatAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 AElements[RowCount],
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount]
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = vec_splat(AElements[Row], Lane);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmMultiplyAddRow
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount],
|
||||
MLAS_FLOAT64X2 BElements[4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[0], Accumulators[Row][0]);
|
||||
Accumulators[Row][1] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[1], Accumulators[Row][1]);
|
||||
Accumulators[Row][2] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[2], Accumulators[Row][2]);
|
||||
Accumulators[Row][3] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[3], Accumulators[Row][3]);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasDgemmComputeBlock(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT64X2 ABroadcast[RowCount],
|
||||
const double* B
|
||||
)
|
||||
{
|
||||
MLAS_FLOAT64X2 BElements[4];
|
||||
|
||||
BElements[0] = MlasLoadFloat64x2(B);
|
||||
BElements[1] = MlasLoadFloat64x2(B + 2);
|
||||
BElements[2] = MlasLoadFloat64x2(B + 4);
|
||||
BElements[3] = MlasLoadFloat64x2(B + 6);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmMultiplyAddRow>()(Accumulators, ABroadcast, BElements);
|
||||
}
|
||||
|
||||
struct MlasDgemmMultiplyAlphaRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[4],
|
||||
MLAS_FLOAT64X2 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MlasMultiplyFloat64x2(Accumulators[Index], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmMultiplyAlphaAddRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[4],
|
||||
MLAS_FLOAT64X2 AlphaBroadcast,
|
||||
const double* C
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MlasMultiplyAddFloat64x2(Accumulators[Index],
|
||||
AlphaBroadcast, MlasLoadFloat64x2(C + Index * 2));
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmStoreRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[4],
|
||||
double* C
|
||||
)
|
||||
{
|
||||
MlasStoreFloat64x2(C + Index * 2, Accumulators[Index]);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t VectorCount>
|
||||
struct MlasDgemmStoreVector
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount][4],
|
||||
double* C,
|
||||
size_t ldc,
|
||||
MLAS_FLOAT64X2 AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
double* c = C + Row * ldc;
|
||||
if (ZeroMode) {
|
||||
MlasLoopUnroll<VectorCount, MlasDgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
|
||||
} else {
|
||||
MlasLoopUnroll<VectorCount, MlasDgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, c);
|
||||
}
|
||||
MlasLoopUnroll<VectorCount, MlasDgemmStoreRow>()(Accumulators[Row], c);
|
||||
|
||||
//
|
||||
// Shift down any unaligned elements to the bottom for further processing.
|
||||
//
|
||||
|
||||
if (VectorCount < 4) {
|
||||
Accumulators[Row][0] = Accumulators[Row][VectorCount];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasDgemmMultiplyAlphaTrailing
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT64X2 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasMultiplyFloat64x2(Accumulators[Row][0], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned Lane>
|
||||
struct MlasDgemmStoreScalar
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT64X2 Accumulators[RowCount][4],
|
||||
double* C,
|
||||
size_t ldc,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
double* c = C + Row * ldc + Lane;
|
||||
double Value = MlasExtractLaneFloat64x2<Lane>(Accumulators[Row][0]);
|
||||
|
||||
if (!ZeroMode) {
|
||||
Value += *c;
|
||||
}
|
||||
|
||||
*c = Value;
|
||||
}
|
||||
};
|
||||
#include "FgemmKernelpower.h"
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
|
|
@ -322,20 +45,20 @@ MlasDgemmProcessCount(
|
|||
// Clear the block accumulators.
|
||||
//
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmZeroAccumulators>()(Accumulators);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmZeroAccumulators>()(Accumulators);
|
||||
|
||||
//
|
||||
// Compute the output block.
|
||||
//
|
||||
while (k >= 2) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmLoadAElements>()(AElements, a, lda);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmSplatAElements<0>>()(AElements, ABroadcast);
|
||||
MlasDgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<0>>()(AElements, ABroadcast);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmSplatAElements<1>>()(AElements, ABroadcast);
|
||||
MlasDgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 8);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<1>>()(AElements, ABroadcast);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 8);
|
||||
|
||||
a += 2;
|
||||
B += 8 * 2;
|
||||
|
|
@ -343,8 +66,8 @@ MlasDgemmProcessCount(
|
|||
}
|
||||
if (k > 0) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmBroadcastAElements>()(ABroadcast, a, lda);
|
||||
MlasDgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmBroadcastAElements>()(ABroadcast, a, lda);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
|
||||
a += 1;
|
||||
B += 8;
|
||||
|
|
@ -357,7 +80,7 @@ MlasDgemmProcessCount(
|
|||
// Store the entire output block.
|
||||
//
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
|
||||
} else {
|
||||
|
||||
|
|
@ -367,11 +90,11 @@ MlasDgemmProcessCount(
|
|||
|
||||
//
|
||||
if (CountN >= 6) {
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountN >= 4) {
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountN >= 2) {
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
//
|
||||
// Store the remaining unaligned columns.
|
||||
|
|
@ -381,9 +104,9 @@ MlasDgemmProcessCount(
|
|||
|
||||
if (CountN > 0) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
|
||||
}
|
||||
|
||||
break;
|
||||
|
|
|
|||
333
onnxruntime/core/mlas/lib/power/FgemmKernelpower.h
Normal file
333
onnxruntime/core/mlas/lib/power/FgemmKernelpower.h
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
FgemmKernelpower.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the single/double precision matrix/matrix
|
||||
multiply operation (DGEMM/SGEMM).
|
||||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
#if defined(SINGLE)
|
||||
#define MLAS_FLOATTYPE MLAS_FLOAT32X4
|
||||
#define MLAS_GEMMTYPE float
|
||||
#define MLAS_LOAD_FLOAT MlasLoadFloat32x4
|
||||
#define MLAS_ZERO_FLOAT MlasZeroFloat32x4
|
||||
#define MLAS_STORE_FLOAT MlasStoreFloat32x4
|
||||
#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat32x4
|
||||
#define MLAS_MUL_FLOAT MlasMultiplyFloat32x4
|
||||
#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat32x4
|
||||
#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat32x4
|
||||
#else
|
||||
#define MLAS_FLOATTYPE MLAS_FLOAT64X2
|
||||
#define MLAS_GEMMTYPE double
|
||||
#define MLAS_LOAD_FLOAT MlasLoadFloat64x2
|
||||
#define MLAS_ZERO_FLOAT MlasZeroFloat64x2
|
||||
#define MLAS_STORE_FLOAT MlasStoreFloat64x2
|
||||
#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat64x2
|
||||
#define MLAS_MUL_FLOAT MlasMultiplyFloat64x2
|
||||
#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat64x2
|
||||
#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat64x2
|
||||
#endif
|
||||
//
|
||||
// Templates to ensure that a loop is unrolled.
|
||||
//
|
||||
|
||||
template<size_t Count, size_t Index>
|
||||
struct MlasLoopUnrollStep
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
IterationType::template Iteration<Count, Index>(Arguments...);
|
||||
MlasLoopUnrollStep<Count, Index + 1>::template Step<IterationType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t Count>
|
||||
struct MlasLoopUnrollStep<Count, Count>
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&...
|
||||
)
|
||||
{
|
||||
// Terminate the loop.
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t Count, typename IteratorType>
|
||||
struct MlasLoopUnroll
|
||||
{
|
||||
template<typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
operator()(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
MlasLoopUnrollStep<Count, 0>::template Step<IteratorType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Templates used with loop unrolling to perform an action on one row of the
|
||||
// output.
|
||||
//
|
||||
|
||||
struct MlasFgemmZeroAccumulators
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[RowCount][4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MLAS_ZERO_FLOAT();
|
||||
Accumulators[Row][1] = MLAS_ZERO_FLOAT();
|
||||
Accumulators[Row][2] = MLAS_ZERO_FLOAT();
|
||||
Accumulators[Row][3] = MLAS_ZERO_FLOAT();
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasFgemmLoadAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE AElements[RowCount],
|
||||
const MLAS_GEMMTYPE* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
AElements[Row] = MLAS_LOAD_FLOAT(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasFgemmBroadcastAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE ABroadcast[RowCount],
|
||||
const MLAS_GEMMTYPE* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = MLAS_BROADCAST_FLOAT(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned Lane>
|
||||
struct MlasFgemmSplatAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE AElements[RowCount],
|
||||
MLAS_FLOATTYPE ABroadcast[RowCount]
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = vec_splat(AElements[Row], Lane);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasFgemmMultiplyAddRow
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[RowCount][4],
|
||||
MLAS_FLOATTYPE ABroadcast[RowCount],
|
||||
MLAS_FLOATTYPE BElements[4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[0], Accumulators[Row][0]);
|
||||
Accumulators[Row][1] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[1], Accumulators[Row][1]);
|
||||
Accumulators[Row][2] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[2], Accumulators[Row][2]);
|
||||
Accumulators[Row][3] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[3], Accumulators[Row][3]);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasFgemmComputeBlock(
|
||||
MLAS_FLOATTYPE Accumulators[RowCount][4],
|
||||
MLAS_FLOATTYPE ABroadcast[RowCount],
|
||||
const MLAS_GEMMTYPE* B
|
||||
)
|
||||
{
|
||||
MLAS_FLOATTYPE BElements[4];
|
||||
#if defined(SINGLE)
|
||||
BElements[0] = MLAS_LOAD_FLOAT(B);
|
||||
BElements[1] = MLAS_LOAD_FLOAT(B + 4);
|
||||
BElements[2] = MLAS_LOAD_FLOAT(B + 8);
|
||||
BElements[3] = MLAS_LOAD_FLOAT(B + 12);
|
||||
#else
|
||||
BElements[0] = MLAS_LOAD_FLOAT(B);
|
||||
BElements[1] = MLAS_LOAD_FLOAT(B + 2);
|
||||
BElements[2] = MLAS_LOAD_FLOAT(B + 4);
|
||||
BElements[3] = MLAS_LOAD_FLOAT(B + 6);
|
||||
#endif
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasFgemmMultiplyAddRow>()(Accumulators, ABroadcast, BElements);
|
||||
}
|
||||
|
||||
struct MlasFgemmMultiplyAlphaRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[4],
|
||||
MLAS_FLOATTYPE AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MLAS_MUL_FLOAT(Accumulators[Index], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasFgemmMultiplyAlphaAddRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[4],
|
||||
MLAS_FLOATTYPE AlphaBroadcast,
|
||||
const MLAS_GEMMTYPE* C
|
||||
)
|
||||
{
|
||||
#if defined(SINGLE)
|
||||
Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index],
|
||||
AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 4));
|
||||
#else
|
||||
Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index],
|
||||
AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 2));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasFgemmStoreRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[4],
|
||||
MLAS_GEMMTYPE* C
|
||||
)
|
||||
{
|
||||
#if defined(SINGLE)
|
||||
MLAS_STORE_FLOAT(C + Index * 4, Accumulators[Index]);
|
||||
#else
|
||||
MLAS_STORE_FLOAT(C + Index * 2, Accumulators[Index]);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t VectorCount>
|
||||
struct MlasFgemmStoreVector
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[RowCount][4],
|
||||
MLAS_GEMMTYPE* C,
|
||||
size_t ldc,
|
||||
MLAS_FLOATTYPE AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
MLAS_GEMMTYPE* c = C + Row * ldc;
|
||||
|
||||
if (ZeroMode) {
|
||||
MlasLoopUnroll<VectorCount, MlasFgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
|
||||
} else {
|
||||
MlasLoopUnroll<VectorCount, MlasFgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, c);
|
||||
}
|
||||
|
||||
MlasLoopUnroll<VectorCount, MlasFgemmStoreRow>()(Accumulators[Row], c);
|
||||
|
||||
//
|
||||
// Shift down any unaligned elements to the bottom for further processing.
|
||||
//
|
||||
|
||||
if (VectorCount < 4) {
|
||||
Accumulators[Row][0] = Accumulators[Row][VectorCount];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasFgemmMultiplyAlphaTrailing
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[RowCount][4],
|
||||
MLAS_FLOATTYPE AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MLAS_MUL_FLOAT(Accumulators[Row][0], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned Lane>
|
||||
struct MlasFgemmStoreScalar
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOATTYPE Accumulators[RowCount][4],
|
||||
MLAS_GEMMTYPE* C,
|
||||
size_t ldc,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
MLAS_GEMMTYPE* c = C + Row * ldc + Lane;
|
||||
MLAS_GEMMTYPE Value = MLAS_EXTRACT_FLOAT<Lane>(Accumulators[Row][0]);
|
||||
|
||||
if (!ZeroMode) {
|
||||
Value += *c;
|
||||
}
|
||||
|
||||
*c = Value;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -188,10 +188,10 @@ MlasSgemmMMAProcessCount(
|
|||
//
|
||||
while (k >= 4) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a, lda);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
|
||||
MlasSgemmComputeAElements<RowCount>(AElements, ABroadcast);
|
||||
if (CountM == 8) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a + ( lda * 4), lda);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a + ( lda * 4), lda);
|
||||
MlasSgemmComputeAElements<RowCount>(AElements, A2Broadcast);
|
||||
}
|
||||
MlasSgemmComputeBlockMMA<RowCount>(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM);
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Licensed under the MIT License.
|
|||
|
||||
Module Name:
|
||||
|
||||
SgemmKernelPower.cpp
|
||||
SgemmKernelpower.h
|
||||
|
||||
Abstract:
|
||||
|
||||
|
|
@ -15,286 +15,7 @@ Abstract:
|
|||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
//
|
||||
// Templates to ensure that a loop is unrolled.
|
||||
//
|
||||
|
||||
template<size_t Count, size_t Index>
|
||||
struct MlasLoopUnrollStep
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
IterationType::template Iteration<Count, Index>(Arguments...);
|
||||
MlasLoopUnrollStep<Count, Index + 1>::template Step<IterationType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t Count>
|
||||
struct MlasLoopUnrollStep<Count, Count>
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&...
|
||||
)
|
||||
{
|
||||
// Terminate the loop.
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t Count, typename IteratorType>
|
||||
struct MlasLoopUnroll
|
||||
{
|
||||
template<typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
operator()(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
MlasLoopUnrollStep<Count, 0>::template Step<IteratorType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Templates used with loop unrolling to perform an action on one row of the
|
||||
// output.
|
||||
//
|
||||
|
||||
struct MlasSgemmZeroAccumulators
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasZeroFloat32x4();
|
||||
Accumulators[Row][1] = MlasZeroFloat32x4();
|
||||
Accumulators[Row][2] = MlasZeroFloat32x4();
|
||||
Accumulators[Row][3] = MlasZeroFloat32x4();
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasSgemmLoadAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 AElements[RowCount],
|
||||
const float* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
AElements[Row] = MlasLoadFloat32x4(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasSgemmBroadcastAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount],
|
||||
const float* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = MlasBroadcastFloat32x4(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned Lane>
|
||||
struct MlasSgemmSplatAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 AElements[RowCount],
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount]
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = vec_splat(AElements[Row], Lane);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasSgemmMultiplyAddRow
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount],
|
||||
MLAS_FLOAT32X4 BElements[4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[0], Accumulators[Row][0]);
|
||||
Accumulators[Row][1] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[1], Accumulators[Row][1]);
|
||||
Accumulators[Row][2] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[2], Accumulators[Row][2]);
|
||||
Accumulators[Row][3] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[3], Accumulators[Row][3]);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasSgemmComputeBlock(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount],
|
||||
const float* B
|
||||
)
|
||||
{
|
||||
MLAS_FLOAT32X4 BElements[4];
|
||||
|
||||
BElements[0] = MlasLoadFloat32x4(B);
|
||||
BElements[1] = MlasLoadFloat32x4(B + 4);
|
||||
BElements[2] = MlasLoadFloat32x4(B + 8);
|
||||
BElements[3] = MlasLoadFloat32x4(B + 12);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmMultiplyAddRow>()(Accumulators, ABroadcast, BElements);
|
||||
}
|
||||
|
||||
struct MlasSgemmMultiplyAlphaRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[4],
|
||||
MLAS_FLOAT32X4 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MlasMultiplyFloat32x4(Accumulators[Index], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasSgemmMultiplyAlphaAddRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[4],
|
||||
MLAS_FLOAT32X4 AlphaBroadcast,
|
||||
const float* C
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MlasMultiplyAddFloat32x4(Accumulators[Index],
|
||||
AlphaBroadcast, MlasLoadFloat32x4(C + Index * 4));
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasSgemmStoreRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[4],
|
||||
float* C
|
||||
)
|
||||
{
|
||||
MlasStoreFloat32x4(C + Index * 4, Accumulators[Index]);
|
||||
}
|
||||
};
|
||||
|
||||
template<size_t VectorCount>
|
||||
struct MlasSgemmStoreVector
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_FLOAT32X4 AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
float* c = C + Row * ldc;
|
||||
|
||||
if (ZeroMode) {
|
||||
MlasLoopUnroll<VectorCount, MlasSgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
|
||||
} else {
|
||||
MlasLoopUnroll<VectorCount, MlasSgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, c);
|
||||
}
|
||||
|
||||
MlasLoopUnroll<VectorCount, MlasSgemmStoreRow>()(Accumulators[Row], c);
|
||||
|
||||
//
|
||||
// Shift down any unaligned elements to the bottom for further processing.
|
||||
//
|
||||
|
||||
if (VectorCount < 4) {
|
||||
Accumulators[Row][0] = Accumulators[Row][VectorCount];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct MlasSgemmMultiplyAlphaTrailing
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT32X4 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasMultiplyFloat32x4(Accumulators[Row][0], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned Lane>
|
||||
struct MlasSgemmStoreScalar
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
float* C,
|
||||
size_t ldc,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
float* c = C + Row * ldc + Lane;
|
||||
float Value = MlasExtractLaneFloat32x4<Lane>(Accumulators[Row][0]);
|
||||
|
||||
if (!ZeroMode) {
|
||||
Value += *c;
|
||||
}
|
||||
|
||||
*c = Value;
|
||||
}
|
||||
};
|
||||
#include "FgemmKernelpower.h"
|
||||
|
||||
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
|
|
@ -324,7 +45,7 @@ MlasSgemmProcessCount(
|
|||
// Clear the block accumulators.
|
||||
//
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmZeroAccumulators>()(Accumulators);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmZeroAccumulators>()(Accumulators);
|
||||
|
||||
//
|
||||
// Compute the output block.
|
||||
|
|
@ -332,19 +53,19 @@ MlasSgemmProcessCount(
|
|||
|
||||
while (k >= 4) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a, lda);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<0>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<0>>()(AElements, ABroadcast);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<1>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 16);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<1>>()(AElements, ABroadcast);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 16);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<2>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 32);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<2>>()(AElements, ABroadcast);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 32);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<3>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 48);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<3>>()(AElements, ABroadcast);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 48);
|
||||
|
||||
a += 4;
|
||||
B += 16 * 4;
|
||||
|
|
@ -353,8 +74,8 @@ MlasSgemmProcessCount(
|
|||
|
||||
while (k > 0) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmBroadcastAElements>()(ABroadcast, a, lda);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmBroadcastAElements>()(ABroadcast, a, lda);
|
||||
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
|
||||
a += 1;
|
||||
B += 16;
|
||||
|
|
@ -367,7 +88,7 @@ MlasSgemmProcessCount(
|
|||
// Store the entire output block.
|
||||
//
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
|
||||
} else {
|
||||
|
||||
|
|
@ -376,11 +97,11 @@ MlasSgemmProcessCount(
|
|||
//
|
||||
|
||||
if (CountN >= 12) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountN >= 8) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountN >= 4) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
|
||||
//
|
||||
|
|
@ -392,16 +113,16 @@ MlasSgemmProcessCount(
|
|||
|
||||
if (CountN > 0) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
|
||||
|
||||
if (CountN >= 2) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<1>>()(Accumulators, C, ldc, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<1>>()(Accumulators, C, ldc, ZeroMode);
|
||||
}
|
||||
|
||||
if (CountN >= 3) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<2>>()(Accumulators, C, ldc, ZeroMode);
|
||||
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<2>>()(Accumulators, C, ldc, ZeroMode);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue