mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
POWER10: Update builtins for DGEMM
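This patch updates the vector-pair builtins used in DGEMM: `__builtin_vsx_assemble_pair` is replaced by `__builtin_vsx_build_pair`, and the operand order passed to `__builtin_mma_assemble_pair` is now selected by endianness. It also replaces C-style casts with `reinterpret_cast` in the SGEMM and DGEMM code for POWER10.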
This commit is contained in:
parent
5d821b5bd9
commit
ad99dff298
2 changed files with 93 additions and 86 deletions
|
|
@ -60,35 +60,42 @@ MlasDgemmComputeBlockMMA(
|
|||
MLAS_FLOAT64X2 BElements[4];
|
||||
typedef __vector unsigned char vec_t;
|
||||
__vector_pair A2pair, Apair;
|
||||
#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 2))
|
||||
__builtin_mma_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]);
|
||||
#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3))
|
||||
#if (__BYTE_ORDER__ != __ORDER_BIG_ENDIAN__)
|
||||
__builtin_mma_assemble_pair (&Apair, reinterpret_cast<vec_t>(ABroadcast[1]), reinterpret_cast<vec_t>(ABroadcast[0]));
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]);
|
||||
}
|
||||
#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2))
|
||||
Apair = *((__vector_pair *)((void *)&ABroadcast[0]));
|
||||
if (CountM == 8) {
|
||||
A2pair = *((__vector_pair *)((void *)&A2Broadcast[0]));
|
||||
__builtin_mma_assemble_pair (&A2pair, reinterpret_cast<vec_t>(A2Broadcast[1]), reinterpret_cast<vec_t>(A2Broadcast[0]));
|
||||
}
|
||||
#else
|
||||
__builtin_vsx_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]);
|
||||
__builtin_mma_assemble_pair (&Apair, reinterpret_cast<vec_t>(ABroadcast[0]), reinterpret_cast<vec_t>(ABroadcast[1]));
|
||||
if (CountM == 8) {
|
||||
__builtin_vsx_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]);
|
||||
__builtin_mma_assemble_pair (&A2pair, reinterpret_cast<vec_t>(A2Broadcast[0]), reinterpret_cast<vec_t>(A2Broadcast[1]));
|
||||
}
|
||||
#endif
|
||||
#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2))
|
||||
Apair = *reinterpret_cast<__vector_pair *>(&ABroadcast[0]);
|
||||
if (CountM == 8) {
|
||||
A2pair = *reinterpret_cast<__vector_pair *>(&A2Broadcast[0]);
|
||||
}
|
||||
#else
|
||||
__builtin_vsx_build_pair (&Apair, reinterpret_cast<vec_t>(ABroadcast[0]), reinterpret_cast<vec_t>(ABroadcast[1]));
|
||||
if (CountM == 8) {
|
||||
__builtin_vsx_build_pair (&A2pair, reinterpret_cast<vec_t>(A2Broadcast[0]), reinterpret_cast<vec_t>(A2Broadcast[1]));
|
||||
}
|
||||
#endif
|
||||
BElements[0] = MlasLoadFloat64x2(B);
|
||||
BElements[1] = MlasLoadFloat64x2(B + 2);
|
||||
BElements[2] = MlasLoadFloat64x2(B + 4);
|
||||
BElements[3] = MlasLoadFloat64x2(B + 6);
|
||||
__builtin_mma_xvf64gerpp (&acc[0], Apair, (vec_t)BElements[0]);
|
||||
__builtin_mma_xvf64gerpp (&acc[1], Apair, (vec_t)BElements[1]);
|
||||
__builtin_mma_xvf64gerpp (&acc[2], Apair, (vec_t)BElements[2]);
|
||||
__builtin_mma_xvf64gerpp (&acc[3], Apair, (vec_t)BElements[3]);
|
||||
__builtin_mma_xvf64gerpp (&acc[0], Apair, reinterpret_cast<vec_t>(BElements[0]));
|
||||
__builtin_mma_xvf64gerpp (&acc[1], Apair, reinterpret_cast<vec_t>(BElements[1]));
|
||||
__builtin_mma_xvf64gerpp (&acc[2], Apair, reinterpret_cast<vec_t>(BElements[2]));
|
||||
__builtin_mma_xvf64gerpp (&acc[3], Apair, reinterpret_cast<vec_t>(BElements[3]));
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_xvf64gerpp (&acc[4], A2pair, (vec_t)BElements[0]);
|
||||
__builtin_mma_xvf64gerpp (&acc[5], A2pair, (vec_t)BElements[1]);
|
||||
__builtin_mma_xvf64gerpp (&acc[6], A2pair, (vec_t)BElements[2]);
|
||||
__builtin_mma_xvf64gerpp (&acc[7], A2pair, (vec_t)BElements[3]);
|
||||
__builtin_mma_xvf64gerpp (&acc[4], A2pair, reinterpret_cast<vec_t>(BElements[0]));
|
||||
__builtin_mma_xvf64gerpp (&acc[5], A2pair, reinterpret_cast<vec_t>(BElements[1]));
|
||||
__builtin_mma_xvf64gerpp (&acc[6], A2pair, reinterpret_cast<vec_t>(BElements[2]));
|
||||
__builtin_mma_xvf64gerpp (&acc[7], A2pair, reinterpret_cast<vec_t>(BElements[3]));
|
||||
}
|
||||
}
|
||||
template<size_t VectorCount>
|
||||
|
|
@ -108,10 +115,10 @@ struct MlasDgemmStoreVectorMMA
|
|||
{
|
||||
MLAS_FLOAT64X2 *rowC;
|
||||
if (ZeroMode) {
|
||||
rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount];
|
||||
rowC = reinterpret_cast<MLAS_FLOAT64X2 *>(&C[Row * ldc + VectorCount]);
|
||||
rowC[0] = Result[Row] * AlphaBroadcast;
|
||||
} else {
|
||||
rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount];
|
||||
rowC = reinterpret_cast<MLAS_FLOAT64X2 *>(&C[Row * ldc + VectorCount]);
|
||||
rowC[0] += Result[Row] * AlphaBroadcast;
|
||||
}
|
||||
}
|
||||
|
|
@ -238,22 +245,22 @@ MlasDgemmMMAProcessCount(
|
|||
//
|
||||
// Store the entire output block.
|
||||
//
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[2]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[3]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[3]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<6>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[6]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[7]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[7]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<6>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -263,60 +270,60 @@ MlasDgemmMMAProcessCount(
|
|||
//
|
||||
|
||||
if (CountN >= 6) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[2]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[6]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 6 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[7]);
|
||||
}
|
||||
}
|
||||
if (CountN - 6 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[3]);
|
||||
}
|
||||
} else if (CountN >= 4) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<2>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 4 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[6]);
|
||||
}
|
||||
}
|
||||
if (CountN - 4 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[2]);
|
||||
}
|
||||
} else if (CountN >= 2) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasDgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 2 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[5]);
|
||||
}
|
||||
}
|
||||
if (CountN - 2 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[1]);
|
||||
}
|
||||
} else {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[0]);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[4]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -68,15 +68,15 @@ MlasSgemmComputeBlockMMA(
|
|||
BElements[1] = MlasLoadFloat32x4(B + 4);
|
||||
BElements[2] = MlasLoadFloat32x4(B + 8);
|
||||
BElements[3] = MlasLoadFloat32x4(B + 12);
|
||||
__builtin_mma_xvf32gerpp (&acc[0], (vec_t) ABroadcast, (vec_t )BElements[0]);
|
||||
__builtin_mma_xvf32gerpp (&acc[1], (vec_t) ABroadcast, (vec_t )BElements[1]);
|
||||
__builtin_mma_xvf32gerpp (&acc[2], (vec_t) ABroadcast, (vec_t )BElements[2]);
|
||||
__builtin_mma_xvf32gerpp (&acc[3], (vec_t) ABroadcast, (vec_t )BElements[3]);
|
||||
__builtin_mma_xvf32gerpp (&acc[0], reinterpret_cast<vec_t>(ABroadcast), reinterpret_cast<vec_t>(BElements[0]));
|
||||
__builtin_mma_xvf32gerpp (&acc[1], reinterpret_cast<vec_t>(ABroadcast), reinterpret_cast<vec_t>(BElements[1]));
|
||||
__builtin_mma_xvf32gerpp (&acc[2], reinterpret_cast<vec_t>(ABroadcast), reinterpret_cast<vec_t>(BElements[2]));
|
||||
__builtin_mma_xvf32gerpp (&acc[3], reinterpret_cast<vec_t>(ABroadcast), reinterpret_cast<vec_t>(BElements[3]));
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_xvf32gerpp (&acc[4], (vec_t) A2Broadcast, (vec_t )BElements[0]);
|
||||
__builtin_mma_xvf32gerpp (&acc[5], (vec_t) A2Broadcast, (vec_t )BElements[1]);
|
||||
__builtin_mma_xvf32gerpp (&acc[6], (vec_t) A2Broadcast, (vec_t )BElements[2]);
|
||||
__builtin_mma_xvf32gerpp (&acc[7], (vec_t) A2Broadcast, (vec_t )BElements[3]);
|
||||
__builtin_mma_xvf32gerpp (&acc[4], reinterpret_cast<vec_t>(A2Broadcast), reinterpret_cast<vec_t>(BElements[0]));
|
||||
__builtin_mma_xvf32gerpp (&acc[5], reinterpret_cast<vec_t>(A2Broadcast), reinterpret_cast<vec_t>(BElements[1]));
|
||||
__builtin_mma_xvf32gerpp (&acc[6], reinterpret_cast<vec_t>(A2Broadcast), reinterpret_cast<vec_t>(BElements[2]));
|
||||
__builtin_mma_xvf32gerpp (&acc[7], reinterpret_cast<vec_t>(A2Broadcast), reinterpret_cast<vec_t>(BElements[3]));
|
||||
}
|
||||
}
|
||||
template<size_t VectorCount>
|
||||
|
|
@ -96,10 +96,10 @@ struct MlasSgemmStoreVectorMMA
|
|||
{
|
||||
MLAS_FLOAT32X4 *rowC;
|
||||
if (ZeroMode) {
|
||||
rowC = (MLAS_FLOAT32X4 *) &C[Row * ldc + VectorCount];
|
||||
rowC = reinterpret_cast<MLAS_FLOAT32X4 *>(&C[Row * ldc + VectorCount]);
|
||||
rowC[0] = Result[Row] * AlphaBroadcast;
|
||||
} else {
|
||||
rowC = (MLAS_FLOAT32X4 *) &C[Row * ldc + VectorCount];
|
||||
rowC = reinterpret_cast<MLAS_FLOAT32X4 *>(&C[Row * ldc + VectorCount]);
|
||||
rowC[0] += Result[Row] * AlphaBroadcast;
|
||||
}
|
||||
}
|
||||
|
|
@ -218,22 +218,22 @@ MlasSgemmMMAProcessCount(
|
|||
//
|
||||
// Store the entire output block.
|
||||
//
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[2]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<8>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[3]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[3]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<12>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[6]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<8>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[7]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[7]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<12>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -243,60 +243,60 @@ MlasSgemmMMAProcessCount(
|
|||
//
|
||||
|
||||
if (CountN >= 12) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[2]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[2]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<8>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[6]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[6]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<8>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 12 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[7]);
|
||||
}
|
||||
}
|
||||
if (CountN - 12 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[3]);
|
||||
}
|
||||
} else if (CountN >= 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[1]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<4>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[5]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<4>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 8 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[6]);
|
||||
}
|
||||
}
|
||||
if (CountN - 8 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[2]);
|
||||
}
|
||||
} else if (CountN >= 4) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[0]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Result, &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Result, &acc[4]);
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVectorMMA<0>>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode);
|
||||
if (CountN - 4 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[5]);
|
||||
}
|
||||
}
|
||||
if (CountN - 4 > 0) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[1]);
|
||||
}
|
||||
} else {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[0], &acc[0]);
|
||||
if (CountM == 8) {
|
||||
__builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]);
|
||||
__builtin_mma_disassemble_acc (Accumulators[1], &acc[4]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue