POWER10: Add optimized dgemm kernel (#9652)

* POWER10: Add optimized dgemm kernel

This patch makes use of POWER10 matrix multiply assist feature and
adds new DGEMM kernel.

* Indentation update

Co-authored-by: Rajalakshmi Srinivasaraghavan <rajis@linux.ibm.com>
This commit is contained in:
RajalakshmiSR 2021-11-22 22:28:21 -06:00 committed by GitHub
parent bf5e9a5044
commit 8564fc1933
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 799 additions and 599 deletions

View file

@ -291,6 +291,7 @@ else()
${MLAS_SRC_DIR}/dgemm.cpp
${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE")
check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10)
if(HAS_POWER10)
set(CMAKE_REQUIRED_FLAGS "-mcpu=power10")
@ -318,8 +319,10 @@ else()
endif()
set(mlas_platform_srcs_power10
${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp
${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp
)
set_source_files_properties(${mlas_platform_srcs_power10} PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE")
set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
set(mlas_platform_srcs
${mlas_platform_srcs}
${mlas_platform_srcs_power10}

View file

@ -499,6 +499,7 @@ extern "C" {
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel;
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelPOWER10;
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernel;
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10;
#else
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero;
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd;
@ -1886,7 +1887,7 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
#if defined(MLAS_SSE2_INTRINSICS)
_mm_store_pd(Buffer, Vector);
#elif defined(MLAS_VSX_INTRINSICS)
vec_st(Vector, 0, Buffer);
*((MLAS_FLOAT64X2*)Buffer) = Vector;
#endif
}

View file

@ -379,6 +379,7 @@ Return Value:
bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1));
if (HasP10Instructions) {
this->GemmFloatKernel = MlasSgemmKernelPOWER10;
this->GemmDoubleKernel = MlasDgemmKernelPOWER10;
}
#endif
#endif

View file

@ -0,0 +1,418 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
DgemmKernelPOWER10.cpp
Abstract:
This module implements the kernels for the double precision matrix/matrix
multiply operation (DGEMM).
--*/
#include "DgemmKernelpower.h"
//
// Loop-unroll functor that gathers one column of matrix A: for the unrolled
// row index Row it copies the scalar found at A[Row * lda] into ARow[Row].
//
struct MlasDgemmBroadcastAElementsMMA
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        double ARow[RowCount],
        const double* A,
        size_t lda
        )
    {
        const double* Source = A + Row * lda;

        ARow[Row] = *Source;
    }
};
//
// Transposes a 4x2 tile of matrix A that was loaded one row per vector.
//
// AElements[r] holds two consecutive K elements of row r. On return:
//
//   ABroadcast[0], ABroadcast[1] - the first K element of rows 0..3
//   ABroadcast[2], ABroadcast[3] - the second K element of rows 0..3
//
// vec_mergee/vec_mergeo pick the even/odd indexed elements of their two
// operands, which for two-element double vectors yields the interleave
// described above.
//
// The routine addresses elements 0..3 of both arrays unconditionally, so it is
// only valid for a row count of exactly four; this is enforced at compile
// time instead of silently reading or writing out of bounds.
//
template<size_t RowCount>
MLAS_FORCEINLINE
void
MlasDgemmComputeAElements(
    MLAS_FLOAT64X2 AElements[RowCount],
    MLAS_FLOAT64X2 ABroadcast[RowCount]
    )
{
    static_assert(RowCount == 4, "MlasDgemmComputeAElements requires RowCount == 4");

    ABroadcast[0] = vec_mergee (AElements[0], AElements[1]);
    ABroadcast[1] = vec_mergee (AElements[2], AElements[3]);
    ABroadcast[2] = vec_mergeo (AElements[0], AElements[1]);
    ABroadcast[3] = vec_mergeo (AElements[2], AElements[3]);
}
//
// Applies one K step of the DGEMM update to the MMA accumulators.
//
// ABroadcast[0..1] supply four A values for the first group of four rows and
// A2Broadcast[0..1] supply four A values for the second group (only consumed
// when CountM == 8). B points at eight consecutive packed B values. Each
// xvf64gerpp builtin accumulates the product of one A pair with one
// two-element B vector into acc[0..3] (first group) or acc[4..7] (second
// group).
//
// The way the two A vectors are combined into a __vector_pair depends on the
// compiler version, hence the preprocessor selection below.
//
template<size_t RowCount>
MLAS_FORCEINLINE
void
MlasDgemmComputeBlockMMA(
    __vector_quad acc[8],
    MLAS_FLOAT64X2 ABroadcast[RowCount],
    MLAS_FLOAT64X2 A2Broadcast[RowCount],
    const double* B,
    size_t CountM
    )
{
    typedef __vector unsigned char vec_t;

    const bool UseSecondGroup = (CountM == 8);

    __vector_pair SecondPair, FirstPair;
    MLAS_FLOAT64X2 BVectors[4];

#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 2))
    __builtin_mma_assemble_pair (&FirstPair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]);
    if (UseSecondGroup) {
        __builtin_mma_assemble_pair (&SecondPair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]);
    }
#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2))
    FirstPair = *((__vector_pair *)((void *)&ABroadcast[0]));
    if (UseSecondGroup) {
        SecondPair = *((__vector_pair *)((void *)&A2Broadcast[0]));
    }
#else
    __builtin_vsx_assemble_pair (&FirstPair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]);
    if (UseSecondGroup) {
        __builtin_vsx_assemble_pair (&SecondPair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]);
    }
#endif

    //
    // Load the eight packed B values as four two-element vectors.
    //

    BVectors[0] = MlasLoadFloat64x2(B);
    BVectors[1] = MlasLoadFloat64x2(B + 2);
    BVectors[2] = MlasLoadFloat64x2(B + 4);
    BVectors[3] = MlasLoadFloat64x2(B + 6);

    //
    // Rows 0..3.
    //

    __builtin_mma_xvf64gerpp (&acc[0], FirstPair, (vec_t)BVectors[0]);
    __builtin_mma_xvf64gerpp (&acc[1], FirstPair, (vec_t)BVectors[1]);
    __builtin_mma_xvf64gerpp (&acc[2], FirstPair, (vec_t)BVectors[2]);
    __builtin_mma_xvf64gerpp (&acc[3], FirstPair, (vec_t)BVectors[3]);

    //
    // Rows 4..7, only when an eight row tile is being produced.
    //

    if (UseSecondGroup) {
        __builtin_mma_xvf64gerpp (&acc[4], SecondPair, (vec_t)BVectors[0]);
        __builtin_mma_xvf64gerpp (&acc[5], SecondPair, (vec_t)BVectors[1]);
        __builtin_mma_xvf64gerpp (&acc[6], SecondPair, (vec_t)BVectors[2]);
        __builtin_mma_xvf64gerpp (&acc[7], SecondPair, (vec_t)BVectors[3]);
    }
}
//
// Loop-unroll functor that writes one two-element vector of a disassembled
// accumulator to matrix C.
//
// The struct's template argument is the column offset (in doubles) inside the
// current row of C at which the vector is stored. Result[Row] is scaled by
// AlphaBroadcast and then either overwrites C (ZeroMode) or is added to the
// value already in C.
//
template<size_t ColumnOffset>
struct MlasDgemmStoreVectorMMA
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOAT64X2 Result[4],
        double* C,
        size_t ldc,
        MLAS_FLOAT64X2 AlphaBroadcast,
        bool ZeroMode
        )
    {
        MLAS_FLOAT64X2* Destination = (MLAS_FLOAT64X2 *) &C[Row * ldc + ColumnOffset];

        if (!ZeroMode) {
            Destination[0] += Result[Row] * AlphaBroadcast;
            return;
        }

        Destination[0] = Result[Row] * AlphaBroadcast;
    }
};
//
// Loop-unroll functor that scales the trailing accumulator of one row by
// alpha, in place.
//
struct MlasDgemmMultiplyAlphaTrailingMMA
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOAT64X2 Accumulators[RowCount],
        MLAS_FLOAT64X2 AlphaBroadcast
        )
    {
        MLAS_FLOAT64X2& Accumulator = Accumulators[Row];

        Accumulator = MlasMultiplyFloat64x2(Accumulator, AlphaBroadcast);
    }
};
//
// Loop-unroll functor that stores a single lane of a row's trailing
// accumulator to C[Row * ldc + Lane], either overwriting the destination
// (ZeroMode) or adding to the value already stored there.
//
template<unsigned Lane>
struct MlasDgemmStoreScalarMMA
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOAT64X2 Accumulators[RowCount],
        double* C,
        size_t ldc,
        bool ZeroMode
        )
    {
        double* Destination = &C[Row * ldc + Lane];
        const double Computed = Accumulators[Row][Lane];

        *Destination = ZeroMode ? Computed : (Computed + *Destination);
    }
};
template<size_t RowCount>
MLAS_FORCEINLINE
size_t
MlasDgemmMMAProcessCount(
    const double* A,
    const double* B,
    double* C,
    size_t CountM,
    size_t CountK,
    size_t CountN,
    size_t lda,
    size_t ldc,
    MLAS_FLOAT64X2 AlphaBroadcast,
    bool ZeroMode
    )
/*++

Routine Description:

    This routine computes a tile of CountM rows of the output matrix using the
    POWER10 MMA accumulators, walking matrix B in panels of eight columns.

    Rows are handled as one or two groups of RowCount (four) rows: acc[0..3]
    hold rows 0..3 and acc[4..7] hold rows 4..7. The second group is only
    loaded, computed and stored when CountM == 8.

Arguments:

    A - Supplies the address of matrix A.

    B - Supplies the address of packed matrix B. Eight values are consumed per
        K step.

    C - Supplies the address of matrix C.

    CountM - Supplies the number of rows to produce. The callers in this file
        pass either 4 or 8.

    CountK - Supplies the number of columns from matrix A and the number of
        rows from matrix B to iterate over.

    CountN - Supplies the number of columns from matrix B and matrix C to
        iterate over.

    lda - Supplies the first dimension of matrix A.

    ldc - Supplies the first dimension of matrix C.

    AlphaBroadcast - Supplies the scalar multiplier broadcast to every lane.

    ZeroMode - Supplies true if the output matrix must be overwritten, else
        false if the output matrix is accumulated into.

Return Value:

    Returns CountM, the number of rows handled.

--*/
{
    do {

        const double* a = A;
        size_t k = CountK;

        MLAS_FLOAT64X2 Accumulators[2][RowCount] = {{ 0 }};
        MLAS_FLOAT64X2 Result[RowCount];
        MLAS_FLOAT64X2 AElements[RowCount];
        MLAS_FLOAT64X2 ABroadcast[RowCount] = { 0 };
        MLAS_FLOAT64X2 A2Broadcast[RowCount] = { 0 };
        MLAS_FLOAT64X2 A3Broadcast[RowCount] = { 0 };
        MLAS_FLOAT64X2 A4Broadcast[RowCount] = { 0 };

        //
        // ARow/A2Row are filled as scalars but handed to
        // MlasDgemmComputeBlockMMA as MLAS_FLOAT64X2 arrays (and, on some
        // compiler versions, read back as a __vector_pair). Request vector
        // alignment explicitly instead of relying on the 8-byte alignment of
        // a plain double array.
        //

        alignas(32) double ARow[RowCount] = { 0 };
        alignas(32) double A2Row[RowCount] = { 0 };

        __vector_quad acc[8];

        //
        // Clear the block accumulators.
        //

        __builtin_mma_xxsetaccz(&acc[0]);
        __builtin_mma_xxsetaccz(&acc[1]);
        __builtin_mma_xxsetaccz(&acc[2]);
        __builtin_mma_xxsetaccz(&acc[3]);
        __builtin_mma_xxsetaccz(&acc[4]);
        __builtin_mma_xxsetaccz(&acc[5]);
        __builtin_mma_xxsetaccz(&acc[6]);
        __builtin_mma_xxsetaccz(&acc[7]);

        //
        // Compute the output block.
        //
        // Four K steps per iteration: ABroadcast/A3Broadcast receive the
        // transposed A values for K offsets {0,1} and {2,3} of rows 0..3,
        // A2Broadcast/A4Broadcast the same for rows 4..7.
        //

        while (k >= 4) {

            MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
            MlasDgemmComputeAElements<RowCount>(AElements, ABroadcast);
            MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a+2, lda);
            MlasDgemmComputeAElements<RowCount>(AElements, A3Broadcast);

            if (CountM == 8) {
                MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a + ( lda * 4), lda);
                MlasDgemmComputeAElements<RowCount>(AElements, A2Broadcast);
                MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, (a+2) + ( lda * 4), lda);
                MlasDgemmComputeAElements<RowCount>(AElements, A4Broadcast);
            }

            MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &ABroadcast[0], &A2Broadcast[0], B, CountM);
            MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &ABroadcast[2], &A2Broadcast[2], B+8, CountM);
            MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &A3Broadcast[0], &A4Broadcast[0], B+16, CountM);
            MlasDgemmComputeBlockMMA<RowCount>(&acc[0], &A3Broadcast[2], &A4Broadcast[2], B+24, CountM);

            B += 8 * 4;
            a += 4;
            k -= 4;
        }

        //
        // Handle the remaining K steps one at a time.
        //

        while (k > 0) {

            MlasLoopUnroll<RowCount, MlasDgemmBroadcastAElementsMMA>()(ARow, a, lda);

            if (CountM == 8) {
                MlasLoopUnroll<RowCount, MlasDgemmBroadcastAElementsMMA>()(A2Row, a + (lda * 4), lda);
            }

            MlasDgemmComputeBlockMMA<RowCount>(&acc[0], (MLAS_FLOAT64X2 *)ARow, (MLAS_FLOAT64X2 *)A2Row, B, CountM);

            a += 1;
            B += 8;
            k -= 1;
        }

        if (CountN >= 8) {

            //
            // Store the entire output block.
            //

            __builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
            MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
            __builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
            MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
            __builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
            MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
            __builtin_mma_disassemble_acc ((void *)Result, &acc[3]);
            MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<6>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);

            if (CountM == 8) {
                __builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                __builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                __builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                __builtin_mma_disassemble_acc ((void *)Result, &acc[7]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<6>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
            }

        } else {

            //
            // Store the partial output block.
            //
            // Whole column pairs are written as vectors. If one odd column
            // remains, the accumulator that contains it is disassembled into
            // Accumulators[0] (rows 0..3) / Accumulators[1] (rows 4..7) for
            // the scalar store below.
            //

            if (CountN >= 6) {

                __builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
                __builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
                __builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);

                if (CountM == 8) {
                    __builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                    __builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                    __builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                    if (CountN > 6) {
                        __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]);
                    }
                }

                if (CountN > 6) {
                    __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]);
                }

            } else if (CountN >= 4) {

                __builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
                __builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);

                if (CountM == 8) {
                    __builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                    __builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                    if (CountN > 4) {
                        __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]);
                    }
                }

                if (CountN > 4) {
                    __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]);
                }

            } else if (CountN >= 2) {

                __builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
                MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);

                if (CountM == 8) {
                    __builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
                    if (CountN > 2) {
                        __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]);
                    }
                }

                if (CountN > 2) {
                    __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]);
                }

            } else {

                __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]);

                if (CountM == 8) {
                    __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]);
                }
            }

            //
            // Store the remaining unaligned columns.
            //

            C += (CountN & ~1);
            CountN &= 1;

            if (CountN > 0) {

                MlasLoopUnroll<RowCount, MlasDgemmMultiplyAlphaTrailingMMA>()(Accumulators[0], AlphaBroadcast);
                MlasLoopUnroll<RowCount, MlasDgemmStoreScalarMMA<0>>()(Accumulators[0], C, ldc, ZeroMode);

                if (CountM == 8) {
                    MlasLoopUnroll<RowCount, MlasDgemmMultiplyAlphaTrailingMMA>()(Accumulators[1], AlphaBroadcast);
                    MlasLoopUnroll<RowCount, MlasDgemmStoreScalarMMA<0>>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode);
                }
            }

            break;
        }

        C += 8;
        CountN -= 8;

    } while (CountN > 0);

    return CountM;
}
size_t
MLASCALL
MlasDgemmKernelPOWER10(
    const double* A,
    const double* B,
    double* C,
    size_t CountK,
    size_t CountM,
    size_t CountN,
    size_t lda,
    size_t ldc,
    double alpha,
    bool ZeroMode
    )
/*++

Routine Description:

    This routine is the POWER10 inner kernel that computes the matrix
    multiplication for a set of rows. Tiles of eight or four rows are
    dispatched to the MMA based implementation; tiles of two rows or one row
    are dispatched to MlasDgemmProcessCount.

Arguments:

    A - Supplies the address of matrix A.

    B - Supplies the address of matrix B. The matrix data has been packed using
        MlasDgemmCopyPackB or MlasDgemmTransposePackB.

    C - Supplies the address of matrix C.

    CountK - Supplies the number of columns from matrix A and the number of rows
        from matrix B to iterate over.

    CountM - Supplies the maximum number of rows that can be processed for
        matrix A and matrix C. The actual number of rows handled for this
        invocation depends on the kernel implementation.

    CountN - Supplies the number of columns from matrix B and matrix C to
        iterate over.

    lda - Supplies the first dimension of matrix A.

    ldc - Supplies the first dimension of matrix C.

    alpha - Supplies the scalar multiplier (see DGEMM definition).

    ZeroMode - Supplies true if the output matrix must be zero initialized,
        else false if the output matrix is accumulated into.

Return Value:

    Returns the number of rows handled.

--*/
{
    MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha);

    //
    // Pick the largest row tile that fits in the rows that remain.
    //

    if (CountM >= 8) {
        return MlasDgemmMMAProcessCount<4>(A, B, C, 8, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
    }

    if (CountM >= 4) {
        return MlasDgemmMMAProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
    }

    if (CountM >= 2) {
        return MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
    }

    return MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
}

View file

@ -6,293 +6,16 @@ Licensed under the MIT License.
Module Name:
DgemmKernelPower.cpp
DgemmKernelpower.h
Abstract:
This module implements the kernels for the single precision matrix/matrix
This module implements the kernels for the double precision matrix/matrix
multiply operation (DGEMM).
--*/
#include "mlasi.h"
//
// Templates to ensure that a loop is unrolled.
//
template<size_t Count, size_t Index>
struct MlasLoopUnrollStep
{
template<typename IterationType, typename... IterationArgs>
MLAS_FORCEINLINE
static
void
Step(
IterationArgs&&... Arguments
)
{
IterationType::template Iteration<Count, Index>(Arguments...);
MlasLoopUnrollStep<Count, Index + 1>::template Step<IterationType>(Arguments...);
}
};
template<size_t Count>
struct MlasLoopUnrollStep<Count, Count>
{
template<typename IterationType, typename... IterationArgs>
MLAS_FORCEINLINE
static
void
Step(
IterationArgs&&...
)
{
// Terminate the loop.
}
};
template<size_t Count, typename IteratorType>
struct MlasLoopUnroll
{
template<typename... IterationArgs>
MLAS_FORCEINLINE
void
operator()(
IterationArgs&&... Arguments
)
{
MlasLoopUnrollStep<Count, 0>::template Step<IteratorType>(Arguments...);
}
};
//
// Templates used with loop unrolling to perform an action on one row of the
// output.
//
struct MlasDgemmZeroAccumulators
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[RowCount][4]
)
{
Accumulators[Row][0] = MlasZeroFloat64x2();
Accumulators[Row][1] = MlasZeroFloat64x2();
Accumulators[Row][2] = MlasZeroFloat64x2();
Accumulators[Row][3] = MlasZeroFloat64x2();
}
};
struct MlasDgemmLoadAElements
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 AElements[RowCount],
const double* A,
size_t lda
)
{
AElements[Row] = MlasLoadFloat64x2(A + Row * lda);
}
};
struct MlasDgemmBroadcastAElements
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 ABroadcast[RowCount],
const double* A,
size_t lda
)
{
ABroadcast[Row] = MlasBroadcastFloat64x2(A + Row * lda);
}
};
template<unsigned Lane>
struct MlasDgemmSplatAElements
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 AElements[RowCount],
MLAS_FLOAT64X2 ABroadcast[RowCount]
)
{
ABroadcast[Row] = vec_splat(AElements[Row], Lane);
}
};
struct MlasDgemmMultiplyAddRow
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[RowCount][4],
MLAS_FLOAT64X2 ABroadcast[RowCount],
MLAS_FLOAT64X2 BElements[4]
)
{
Accumulators[Row][0] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[0], Accumulators[Row][0]);
Accumulators[Row][1] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[1], Accumulators[Row][1]);
Accumulators[Row][2] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[2], Accumulators[Row][2]);
Accumulators[Row][3] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[3], Accumulators[Row][3]);
}
};
template<size_t RowCount>
MLAS_FORCEINLINE
void
MlasDgemmComputeBlock(
MLAS_FLOAT64X2 Accumulators[RowCount][4],
MLAS_FLOAT64X2 ABroadcast[RowCount],
const double* B
)
{
MLAS_FLOAT64X2 BElements[4];
BElements[0] = MlasLoadFloat64x2(B);
BElements[1] = MlasLoadFloat64x2(B + 2);
BElements[2] = MlasLoadFloat64x2(B + 4);
BElements[3] = MlasLoadFloat64x2(B + 6);
MlasLoopUnroll<RowCount, MlasDgemmMultiplyAddRow>()(Accumulators, ABroadcast, BElements);
}
struct MlasDgemmMultiplyAlphaRow
{
template<size_t Count, size_t Index>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[4],
MLAS_FLOAT64X2 AlphaBroadcast
)
{
Accumulators[Index] = MlasMultiplyFloat64x2(Accumulators[Index], AlphaBroadcast);
}
};
struct MlasDgemmMultiplyAlphaAddRow
{
template<size_t Count, size_t Index>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[4],
MLAS_FLOAT64X2 AlphaBroadcast,
const double* C
)
{
Accumulators[Index] = MlasMultiplyAddFloat64x2(Accumulators[Index],
AlphaBroadcast, MlasLoadFloat64x2(C + Index * 2));
}
};
struct MlasDgemmStoreRow
{
template<size_t Count, size_t Index>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[4],
double* C
)
{
MlasStoreFloat64x2(C + Index * 2, Accumulators[Index]);
}
};
template<size_t VectorCount>
struct MlasDgemmStoreVector
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[RowCount][4],
double* C,
size_t ldc,
MLAS_FLOAT64X2 AlphaBroadcast,
bool ZeroMode
)
{
double* c = C + Row * ldc;
if (ZeroMode) {
MlasLoopUnroll<VectorCount, MlasDgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
} else {
MlasLoopUnroll<VectorCount, MlasDgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, c);
}
MlasLoopUnroll<VectorCount, MlasDgemmStoreRow>()(Accumulators[Row], c);
//
// Shift down any unaligned elements to the bottom for further processing.
//
if (VectorCount < 4) {
Accumulators[Row][0] = Accumulators[Row][VectorCount];
}
}
};
struct MlasDgemmMultiplyAlphaTrailing
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[RowCount][4],
MLAS_FLOAT64X2 AlphaBroadcast
)
{
Accumulators[Row][0] = MlasMultiplyFloat64x2(Accumulators[Row][0], AlphaBroadcast);
}
};
template<unsigned Lane>
struct MlasDgemmStoreScalar
{
template<size_t RowCount, size_t Row>
MLAS_FORCEINLINE
static
void
Iteration(
MLAS_FLOAT64X2 Accumulators[RowCount][4],
double* C,
size_t ldc,
bool ZeroMode
)
{
double* c = C + Row * ldc + Lane;
double Value = MlasExtractLaneFloat64x2<Lane>(Accumulators[Row][0]);
if (!ZeroMode) {
Value += *c;
}
*c = Value;
}
};
#include "FgemmKernelpower.h"
template<size_t RowCount>
MLAS_FORCEINLINE
@ -322,20 +45,20 @@ MlasDgemmProcessCount(
// Clear the block accumulators.
//
MlasLoopUnroll<RowCount, MlasDgemmZeroAccumulators>()(Accumulators);
MlasLoopUnroll<RowCount, MlasFgemmZeroAccumulators>()(Accumulators);
//
// Compute the output block.
//
while (k >= 2) {
MlasLoopUnroll<RowCount, MlasDgemmLoadAElements>()(AElements, a, lda);
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
MlasLoopUnroll<RowCount, MlasDgemmSplatAElements<0>>()(AElements, ABroadcast);
MlasDgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<0>>()(AElements, ABroadcast);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
MlasLoopUnroll<RowCount, MlasDgemmSplatAElements<1>>()(AElements, ABroadcast);
MlasDgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 8);
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<1>>()(AElements, ABroadcast);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 8);
a += 2;
B += 8 * 2;
@ -343,8 +66,8 @@ MlasDgemmProcessCount(
}
if (k > 0) {
MlasLoopUnroll<RowCount, MlasDgemmBroadcastAElements>()(ABroadcast, a, lda);
MlasDgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
MlasLoopUnroll<RowCount, MlasFgemmBroadcastAElements>()(ABroadcast, a, lda);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
a += 1;
B += 8;
@ -357,7 +80,7 @@ MlasDgemmProcessCount(
// Store the entire output block.
//
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
} else {
@ -367,11 +90,11 @@ MlasDgemmProcessCount(
//
if (CountN >= 6) {
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
} else if (CountN >= 4) {
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
} else if (CountN >= 2) {
MlasLoopUnroll<RowCount, MlasDgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
}
//
// Store the remaining unaligned columns.
@ -381,9 +104,9 @@ MlasDgemmProcessCount(
if (CountN > 0) {
MlasLoopUnroll<RowCount, MlasDgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
MlasLoopUnroll<RowCount, MlasFgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
MlasLoopUnroll<RowCount, MlasDgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
}
break;

View file

@ -0,0 +1,333 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
FgemmKernelpower.h
Abstract:
This module implements the kernels for the single/double precision matrix/matrix
multiply operation (DGEMM/SGEMM).
--*/
#include "mlasi.h"
//
// Precision selection for the shared GEMM templates below.
//
// A translation unit compiled with SINGLE defined (the build passes -DSINGLE
// for the SGEMM kernel sources) gets the float / MLAS_FLOAT32X4 flavor of
// every helper; otherwise the double / MLAS_FLOAT64X2 flavor is used.
//
//   MLAS_FLOATTYPE        - vector type used for accumulators and operands
//   MLAS_GEMMTYPE         - scalar element type of the A, B and C matrices
//   MLAS_LOAD_FLOAT       - vector load
//   MLAS_ZERO_FLOAT       - vector of zeros
//   MLAS_STORE_FLOAT      - vector store
//   MLAS_EXTRACT_FLOAT    - extraction of a single lane
//   MLAS_MUL_FLOAT        - vector multiply
//   MLAS_MULADD_FLOAT     - vector multiply-add
//   MLAS_BROADCAST_FLOAT  - broadcast of one scalar to every lane
//
#if defined(SINGLE)
#define MLAS_FLOATTYPE MLAS_FLOAT32X4
#define MLAS_GEMMTYPE float
#define MLAS_LOAD_FLOAT MlasLoadFloat32x4
#define MLAS_ZERO_FLOAT MlasZeroFloat32x4
#define MLAS_STORE_FLOAT MlasStoreFloat32x4
#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat32x4
#define MLAS_MUL_FLOAT MlasMultiplyFloat32x4
#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat32x4
#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat32x4
#else
#define MLAS_FLOATTYPE MLAS_FLOAT64X2
#define MLAS_GEMMTYPE double
#define MLAS_LOAD_FLOAT MlasLoadFloat64x2
#define MLAS_ZERO_FLOAT MlasZeroFloat64x2
#define MLAS_STORE_FLOAT MlasStoreFloat64x2
#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat64x2
#define MLAS_MUL_FLOAT MlasMultiplyFloat64x2
#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat64x2
#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat64x2
#endif
//
// Templates to ensure that a loop is unrolled.
//
//
// One step of a compile-time unrolled loop: invokes
// Functor::Iteration<Count, Index> with the supplied arguments and then
// chains to the step for Index + 1.
//
template<size_t Count, size_t Index>
struct MlasLoopUnrollStep
{
    template<typename Functor, typename... Args>
    MLAS_FORCEINLINE
    static
    void
    Step(
        Args&&... Arguments
        )
    {
        using NextStep = MlasLoopUnrollStep<Count, Index + 1>;

        Functor::template Iteration<Count, Index>(Arguments...);
        NextStep::template Step<Functor>(Arguments...);
    }
};

//
// Specialization for Index == Count, which ends the recursion.
//
template<size_t Count>
struct MlasLoopUnrollStep<Count, Count>
{
    template<typename Functor, typename... Args>
    MLAS_FORCEINLINE
    static
    void
    Step(
        Args&&...
        )
    {
        // Terminate the loop.
    }
};
//
// Entry point of the unrolled loop: calling an instance runs
// IteratorType::Iteration<Count, I> for I = 0 .. Count - 1, passing the same
// arguments to every iteration.
//
template<size_t Count, typename IteratorType>
struct MlasLoopUnroll
{
    template<typename... Args>
    MLAS_FORCEINLINE
    void
    operator()(
        Args&&... Arguments
        )
    {
        using FirstStep = MlasLoopUnrollStep<Count, 0>;

        FirstStep::template Step<IteratorType>(Arguments...);
    }
};
//
// Templates used with loop unrolling to perform an action on one row of the
// output.
//
//
// Loop-unroll functor that clears the four accumulator vectors of one row.
//
struct MlasFgemmZeroAccumulators
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[RowCount][4]
        )
    {
        MLAS_FLOATTYPE* RowAccumulators = Accumulators[Row];

        RowAccumulators[0] = MLAS_ZERO_FLOAT();
        RowAccumulators[1] = MLAS_ZERO_FLOAT();
        RowAccumulators[2] = MLAS_ZERO_FLOAT();
        RowAccumulators[3] = MLAS_ZERO_FLOAT();
    }
};
//
// Loop-unroll functor that loads one vector of consecutive elements from row
// Row of matrix A (starting at A + Row * lda) into AElements[Row].
//
struct MlasFgemmLoadAElements
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE AElements[RowCount],
        const MLAS_GEMMTYPE* A,
        size_t lda
        )
    {
        const MLAS_GEMMTYPE* RowStart = A + Row * lda;

        AElements[Row] = MLAS_LOAD_FLOAT(RowStart);
    }
};
//
// Loop-unroll functor that broadcasts the single element at A + Row * lda
// across ABroadcast[Row].
//
struct MlasFgemmBroadcastAElements
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE ABroadcast[RowCount],
        const MLAS_GEMMTYPE* A,
        size_t lda
        )
    {
        const MLAS_GEMMTYPE* Element = A + Row * lda;

        ABroadcast[Row] = MLAS_BROADCAST_FLOAT(Element);
    }
};
//
// Loop-unroll functor that replicates lane Lane of AElements[Row] across
// ABroadcast[Row].
//
template<unsigned Lane>
struct MlasFgemmSplatAElements
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE AElements[RowCount],
        MLAS_FLOATTYPE ABroadcast[RowCount]
        )
    {
        MLAS_FLOATTYPE& Destination = ABroadcast[Row];

        Destination = vec_splat(AElements[Row], Lane);
    }
};
//
// Loop-unroll functor that accumulates ABroadcast[Row] * BElements[i] into
// each of the four accumulator vectors of row Row.
//
struct MlasFgemmMultiplyAddRow
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[RowCount][4],
        MLAS_FLOATTYPE ABroadcast[RowCount],
        MLAS_FLOATTYPE BElements[4]
        )
    {
        MLAS_FLOATTYPE* RowAccumulators = Accumulators[Row];

        RowAccumulators[0] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[0], RowAccumulators[0]);
        RowAccumulators[1] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[1], RowAccumulators[1]);
        RowAccumulators[2] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[2], RowAccumulators[2]);
        RowAccumulators[3] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[3], RowAccumulators[3]);
    }
};
//
// Loads four consecutive vectors from packed matrix B and accumulates their
// product with the broadcast A values into every row's accumulators.
//
template<size_t RowCount>
MLAS_FORCEINLINE
void
MlasFgemmComputeBlock(
    MLAS_FLOATTYPE Accumulators[RowCount][4],
    MLAS_FLOATTYPE ABroadcast[RowCount],
    const MLAS_GEMMTYPE* B
    )
{
    //
    // Distance, in elements, between consecutive vectors of B.
    //

#if defined(SINGLE)
    const size_t VectorStride = 4;
#else
    const size_t VectorStride = 2;
#endif

    MLAS_FLOATTYPE BElements[4];

    BElements[0] = MLAS_LOAD_FLOAT(B);
    BElements[1] = MLAS_LOAD_FLOAT(B + VectorStride);
    BElements[2] = MLAS_LOAD_FLOAT(B + 2 * VectorStride);
    BElements[3] = MLAS_LOAD_FLOAT(B + 3 * VectorStride);

    MlasLoopUnroll<RowCount, MlasFgemmMultiplyAddRow>()(Accumulators, ABroadcast, BElements);
}
//
// Loop-unroll functor that scales accumulator vector Index of a row by alpha,
// in place.
//
struct MlasFgemmMultiplyAlphaRow
{
    template<size_t Count, size_t Index>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[4],
        MLAS_FLOATTYPE AlphaBroadcast
        )
    {
        MLAS_FLOATTYPE& Accumulator = Accumulators[Index];

        Accumulator = MLAS_MUL_FLOAT(Accumulator, AlphaBroadcast);
    }
};
//
// Loop-unroll functor that replaces accumulator vector Index of a row with
// Accumulators[Index] * alpha plus the matching vector already stored in C.
//
struct MlasFgemmMultiplyAlphaAddRow
{
    template<size_t Count, size_t Index>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[4],
        MLAS_FLOATTYPE AlphaBroadcast,
        const MLAS_GEMMTYPE* C
        )
    {
        //
        // Distance, in elements, between consecutive vectors of C.
        //

#if defined(SINGLE)
        const size_t VectorStride = 4;
#else
        const size_t VectorStride = 2;
#endif

        Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index],
            AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * VectorStride));
    }
};
//
// Loop-unroll functor that stores accumulator vector Index of a row to the
// matching vector position of C.
//
struct MlasFgemmStoreRow
{
    template<size_t Count, size_t Index>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[4],
        MLAS_GEMMTYPE* C
        )
    {
        //
        // Distance, in elements, between consecutive vectors of C.
        //

#if defined(SINGLE)
        const size_t VectorStride = 4;
#else
        const size_t VectorStride = 2;
#endif

        MLAS_STORE_FLOAT(C + Index * VectorStride, Accumulators[Index]);
    }
};
//
// Loop-unroll functor that finishes and stores the first VectorCount
// accumulator vectors of row Row to C.
//
// The accumulators are scaled by alpha (ZeroMode) or scaled and added to the
// existing contents of C (otherwise) before being stored. When fewer than
// four vectors are stored, the next accumulator is moved to slot 0 so the
// trailing scalar store can find the leftover elements there.
//
template<size_t VectorCount>
struct MlasFgemmStoreVector
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[RowCount][4],
        MLAS_GEMMTYPE* C,
        size_t ldc,
        MLAS_FLOATTYPE AlphaBroadcast,
        bool ZeroMode
        )
    {
        MLAS_GEMMTYPE* RowC = C + Row * ldc;

        if (!ZeroMode) {
            MlasLoopUnroll<VectorCount, MlasFgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, RowC);
        } else {
            MlasLoopUnroll<VectorCount, MlasFgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
        }

        MlasLoopUnroll<VectorCount, MlasFgemmStoreRow>()(Accumulators[Row], RowC);

        //
        // Shift down any unaligned elements to the bottom for further processing.
        //

        if (VectorCount < 4) {
            Accumulators[Row][0] = Accumulators[Row][VectorCount];
        }
    }
};
//
// Loop-unroll functor that scales the first accumulator vector of row Row by
// alpha, in place.
//
struct MlasFgemmMultiplyAlphaTrailing
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[RowCount][4],
        MLAS_FLOATTYPE AlphaBroadcast
        )
    {
        MLAS_FLOATTYPE& Trailing = Accumulators[Row][0];

        Trailing = MLAS_MUL_FLOAT(Trailing, AlphaBroadcast);
    }
};
//
// Loop-unroll functor that stores lane Lane of a row's first accumulator
// vector to C[Row * ldc + Lane], either overwriting the destination
// (ZeroMode) or adding to the value already stored there.
//
template<unsigned Lane>
struct MlasFgemmStoreScalar
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOATTYPE Accumulators[RowCount][4],
        MLAS_GEMMTYPE* C,
        size_t ldc,
        bool ZeroMode
        )
    {
        MLAS_GEMMTYPE* Destination = &C[Row * ldc + Lane];
        const MLAS_GEMMTYPE Computed = MLAS_EXTRACT_FLOAT<Lane>(Accumulators[Row][0]);

        *Destination = ZeroMode ? Computed : (Computed + *Destination);
    }
};

View file

@ -188,10 +188,10 @@ MlasSgemmMMAProcessCount(
//
while (k >= 4) {
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a, lda);
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
MlasSgemmComputeAElements<RowCount>(AElements, ABroadcast);
if (CountM == 8) {
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a + ( lda * 4), lda);
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a + ( lda * 4), lda);
MlasSgemmComputeAElements<RowCount>(AElements, A2Broadcast);
}
MlasSgemmComputeBlockMMA<RowCount>(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM);

View file

@ -6,7 +6,7 @@ Licensed under the MIT License.
Module Name:
SgemmKernelPower.cpp
SgemmKernelpower.h
Abstract:
@ -15,286 +15,7 @@ Abstract:
--*/
#include "mlasi.h"
//
// Templates to ensure that a loop is unrolled.
//
template<size_t Count, size_t Index>
struct MlasLoopUnrollStep
{
template<typename IterationType, typename... IterationArgs>
MLAS_FORCEINLINE
static
void
Step(
IterationArgs&&... Arguments
)
{
IterationType::template Iteration<Count, Index>(Arguments...);
MlasLoopUnrollStep<Count, Index + 1>::template Step<IterationType>(Arguments...);
}
};
template<size_t Count>
struct MlasLoopUnrollStep<Count, Count>
{
template<typename IterationType, typename... IterationArgs>
MLAS_FORCEINLINE
static
void
Step(
IterationArgs&&...
)
{
// Terminate the loop.
}
};
template<size_t Count, typename IteratorType>
struct MlasLoopUnroll
{
template<typename... IterationArgs>
MLAS_FORCEINLINE
void
operator()(
IterationArgs&&... Arguments
)
{
MlasLoopUnrollStep<Count, 0>::template Step<IteratorType>(Arguments...);
}
};
//
// Templates used with loop unrolling to perform an action on one row of the
// output.
//
//
// Loop-unroll iteration that resets the four accumulator vectors of one row
// to the value returned by MlasZeroFloat32x4().
//
struct MlasSgemmZeroAccumulators
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 Accumulators[RowCount][4])
    {
        MLAS_FLOAT32X4* RowAccumulators = Accumulators[Row];

        RowAccumulators[0] = MlasZeroFloat32x4();
        RowAccumulators[1] = MlasZeroFloat32x4();
        RowAccumulators[2] = MlasZeroFloat32x4();
        RowAccumulators[3] = MlasZeroFloat32x4();
    }
};
//
// Loop-unroll iteration that loads one vector from matrix A for the given
// row; rows are separated by lda elements.
//
struct MlasSgemmLoadAElements
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 AElements[RowCount], const float* A, size_t lda)
    {
        const float* RowA = A + Row * lda;

        AElements[Row] = MlasLoadFloat32x4(RowA);
    }
};
//
// Loop-unroll iteration that fills ABroadcast[Row] by calling
// MlasBroadcastFloat32x4 on the address of the row's current A element;
// rows are separated by lda elements.
//
struct MlasSgemmBroadcastAElements
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 ABroadcast[RowCount], const float* A, size_t lda)
    {
        const float* RowA = A + Row * lda;

        ABroadcast[Row] = MlasBroadcastFloat32x4(RowA);
    }
};
//
// Loop-unroll iteration that replicates lane Lane of the previously loaded
// A vector across ABroadcast[Row] using vec_splat.
//
template<unsigned Lane>
struct MlasSgemmSplatAElements
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 AElements[RowCount], MLAS_FLOAT32X4 ABroadcast[RowCount])
    {
        const MLAS_FLOAT32X4 Source = AElements[Row];

        ABroadcast[Row] = vec_splat(Source, Lane);
    }
};
//
// Loop-unroll iteration that updates the four accumulator vectors of one row
// via MlasMultiplyAddFloat32x4, combining the row's broadcast A value with
// each of the four B vectors and the current accumulator.
//
struct MlasSgemmMultiplyAddRow
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOAT32X4 Accumulators[RowCount][4],
        MLAS_FLOAT32X4 ABroadcast[RowCount],
        MLAS_FLOAT32X4 BElements[4]
        )
    {
        const MLAS_FLOAT32X4 ARow = ABroadcast[Row];
        MLAS_FLOAT32X4* RowAccumulators = Accumulators[Row];

        RowAccumulators[0] = MlasMultiplyAddFloat32x4(ARow, BElements[0], RowAccumulators[0]);
        RowAccumulators[1] = MlasMultiplyAddFloat32x4(ARow, BElements[1], RowAccumulators[1]);
        RowAccumulators[2] = MlasMultiplyAddFloat32x4(ARow, BElements[2], RowAccumulators[2]);
        RowAccumulators[3] = MlasMultiplyAddFloat32x4(ARow, BElements[3], RowAccumulators[3]);
    }
};
//
// Loads four consecutive vectors (sixteen floats) from B and applies
// MlasSgemmMultiplyAddRow to every one of the RowCount accumulator rows.
//
template<size_t RowCount>
MLAS_FORCEINLINE
void
MlasSgemmComputeBlock(
    MLAS_FLOAT32X4 Accumulators[RowCount][4],
    MLAS_FLOAT32X4 ABroadcast[RowCount],
    const float* B
    )
{
    MLAS_FLOAT32X4 BElements[4];

    BElements[0] = MlasLoadFloat32x4(&B[0]);
    BElements[1] = MlasLoadFloat32x4(&B[4]);
    BElements[2] = MlasLoadFloat32x4(&B[8]);
    BElements[3] = MlasLoadFloat32x4(&B[12]);

    MlasLoopUnroll<RowCount, MlasSgemmMultiplyAddRow> MultiplyAddRows;
    MultiplyAddRows(Accumulators, ABroadcast, BElements);
}
//
// Loop-unroll iteration that scales accumulator vector Index of a single row
// by the broadcast alpha value.
//
struct MlasSgemmMultiplyAlphaRow
{
    template<size_t Count, size_t Index>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 Accumulators[4], MLAS_FLOAT32X4 AlphaBroadcast)
    {
        MLAS_FLOAT32X4& Target = Accumulators[Index];

        Target = MlasMultiplyFloat32x4(Target, AlphaBroadcast);
    }
};
//
// Loop-unroll iteration that replaces accumulator vector Index of a single
// row with MlasMultiplyAddFloat32x4(accumulator, alpha, C vector), where the
// C vector is loaded from C + Index * 4.
//
struct MlasSgemmMultiplyAlphaAddRow
{
    template<size_t Count, size_t Index>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 Accumulators[4], MLAS_FLOAT32X4 AlphaBroadcast, const float* C)
    {
        const MLAS_FLOAT32X4 CElements = MlasLoadFloat32x4(C + Index * 4);

        Accumulators[Index] =
            MlasMultiplyAddFloat32x4(Accumulators[Index], AlphaBroadcast, CElements);
    }
};
//
// Loop-unroll iteration that stores accumulator vector Index of a single row
// to C + Index * 4.
//
struct MlasSgemmStoreRow
{
    template<size_t Count, size_t Index>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 Accumulators[4], float* C)
    {
        float* Destination = C + Index * 4;

        MlasStoreFloat32x4(Destination, Accumulators[Index]);
    }
};
//
// Loop-unroll iteration that finalizes and stores the first VectorCount
// accumulator vectors of one output row. In ZeroMode the accumulators are
// only scaled by alpha; otherwise the existing contents of C are folded in
// through MlasSgemmMultiplyAlphaAddRow before the store.
//
template<size_t VectorCount>
struct MlasSgemmStoreVector
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOAT32X4 Accumulators[RowCount][4],
        float* C,
        size_t ldc,
        MLAS_FLOAT32X4 AlphaBroadcast,
        bool ZeroMode
        )
    {
        float* c = C + Row * ldc;

        if (!ZeroMode) {
            MlasLoopUnroll<VectorCount, MlasSgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, c);
        } else {
            MlasLoopUnroll<VectorCount, MlasSgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
        }

        MlasLoopUnroll<VectorCount, MlasSgemmStoreRow>()(Accumulators[Row], c);

        //
        // When fewer than four vectors were written, move the first vector
        // that was not stored into slot zero so that later processing of the
        // remaining columns can read it from there.
        //

        if (VectorCount < 4) {
            Accumulators[Row][0] = Accumulators[Row][VectorCount];
        }
    }
};
//
// Loop-unroll iteration that scales the first accumulator vector of one row
// by the broadcast alpha value.
//
struct MlasSgemmMultiplyAlphaTrailing
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(MLAS_FLOAT32X4 Accumulators[RowCount][4], MLAS_FLOAT32X4 AlphaBroadcast)
    {
        MLAS_FLOAT32X4& Trailing = Accumulators[Row][0];

        Trailing = MlasMultiplyFloat32x4(Trailing, AlphaBroadcast);
    }
};
//
// Loop-unroll iteration that writes lane Lane of the first accumulator
// vector of one row to C[Row * ldc + Lane]. Unless ZeroMode is set, the
// value already stored at that location is added to the lane first.
//
template<unsigned Lane>
struct MlasSgemmStoreScalar
{
    template<size_t RowCount, size_t Row>
    MLAS_FORCEINLINE
    static
    void
    Iteration(
        MLAS_FLOAT32X4 Accumulators[RowCount][4],
        float* C,
        size_t ldc,
        bool ZeroMode
        )
    {
        float* Output = C + Row * ldc + Lane;
        const float Extracted = MlasExtractLaneFloat32x4<Lane>(Accumulators[Row][0]);

        // The destination is only read when accumulating into existing data.
        if (ZeroMode) {
            *Output = Extracted;
        } else {
            *Output = Extracted + *Output;
        }
    }
};
#include "FgemmKernelpower.h"
template<size_t RowCount>
MLAS_FORCEINLINE
@ -324,7 +45,7 @@ MlasSgemmProcessCount(
// Clear the block accumulators.
//
MlasLoopUnroll<RowCount, MlasSgemmZeroAccumulators>()(Accumulators);
MlasLoopUnroll<RowCount, MlasFgemmZeroAccumulators>()(Accumulators);
//
// Compute the output block.
@ -332,19 +53,19 @@ MlasSgemmProcessCount(
while (k >= 4) {
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a, lda);
MlasLoopUnroll<RowCount, MlasFgemmLoadAElements>()(AElements, a, lda);
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<0>>()(AElements, ABroadcast);
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<0>>()(AElements, ABroadcast);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<1>>()(AElements, ABroadcast);
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 16);
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<1>>()(AElements, ABroadcast);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 16);
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<2>>()(AElements, ABroadcast);
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 32);
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<2>>()(AElements, ABroadcast);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 32);
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<3>>()(AElements, ABroadcast);
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 48);
MlasLoopUnroll<RowCount, MlasFgemmSplatAElements<3>>()(AElements, ABroadcast);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 48);
a += 4;
B += 16 * 4;
@ -353,8 +74,8 @@ MlasSgemmProcessCount(
while (k > 0) {
MlasLoopUnroll<RowCount, MlasSgemmBroadcastAElements>()(ABroadcast, a, lda);
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
MlasLoopUnroll<RowCount, MlasFgemmBroadcastAElements>()(ABroadcast, a, lda);
MlasFgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
a += 1;
B += 16;
@ -367,7 +88,7 @@ MlasSgemmProcessCount(
// Store the entire output block.
//
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
} else {
@ -376,11 +97,11 @@ MlasSgemmProcessCount(
//
if (CountN >= 12) {
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
} else if (CountN >= 8) {
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
} else if (CountN >= 4) {
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
}
//
@ -392,16 +113,16 @@ MlasSgemmProcessCount(
if (CountN > 0) {
MlasLoopUnroll<RowCount, MlasSgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
MlasLoopUnroll<RowCount, MlasFgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
if (CountN >= 2) {
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<1>>()(Accumulators, C, ldc, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<1>>()(Accumulators, C, ldc, ZeroMode);
}
if (CountN >= 3) {
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<2>>()(Accumulators, C, ldc, ZeroMode);
MlasLoopUnroll<RowCount, MlasFgemmStoreScalar<2>>()(Accumulators, C, ldc, ZeroMode);
}
}