From ad99dff2980285ff064d9c846b0ac3d47f14bbfb Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Fri, 10 Dec 2021 22:08:44 -0600 Subject: [PATCH] POWER10: Update builtins for DGEMM This patch changes builtin names in DGEMM based on endianness order. Also changing some casting style in SGEMM and DGEMM code for POWER10. --- .../mlas/lib/power/DgemmKernelPOWER10.cpp | 103 ++++++++++-------- .../mlas/lib/power/SgemmKernelPOWER10.cpp | 76 ++++++------- 2 files changed, 93 insertions(+), 86 deletions(-) diff --git a/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp b/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp index 11638bc33f..560df1e618 100644 --- a/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp +++ b/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp @@ -60,35 +60,42 @@ MlasDgemmComputeBlockMMA( MLAS_FLOAT64X2 BElements[4]; typedef __vector unsigned char vec_t; __vector_pair A2pair, Apair; -#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 2)) - __builtin_mma_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]); +#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3)) +#if (__BYTE_ORDER__ != __ORDER_BIG_ENDIAN__) + __builtin_mma_assemble_pair (&Apair, reinterpret_cast(ABroadcast[1]), reinterpret_cast(ABroadcast[0])); if (CountM == 8) { - __builtin_mma_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]); - } -#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2)) - Apair = *((__vector_pair *)((void *)&ABroadcast[0])); - if (CountM == 8) { - A2pair = *((__vector_pair *)((void *)&A2Broadcast[0])); + __builtin_mma_assemble_pair (&A2pair, reinterpret_cast(A2Broadcast[1]), reinterpret_cast(A2Broadcast[0])); } #else - __builtin_vsx_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]); + __builtin_mma_assemble_pair (&Apair, reinterpret_cast(ABroadcast[0]), reinterpret_cast(ABroadcast[1])); if (CountM == 8) { - __builtin_vsx_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]); + __builtin_mma_assemble_pair (&A2pair, reinterpret_cast(A2Broadcast[0]), reinterpret_cast(A2Broadcast[1])); + } +#endif +#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2)) + Apair = *reinterpret_cast<__vector_pair *>(&ABroadcast[0]); + if (CountM == 8) { + A2pair = *reinterpret_cast<__vector_pair *>(&A2Broadcast[0]); + } +#else + __builtin_vsx_build_pair (&Apair, reinterpret_cast(ABroadcast[0]), reinterpret_cast(ABroadcast[1])); + if (CountM == 8) { + __builtin_vsx_build_pair (&A2pair, reinterpret_cast(A2Broadcast[0]), reinterpret_cast(A2Broadcast[1])); } #endif BElements[0] = MlasLoadFloat64x2(B); BElements[1] = MlasLoadFloat64x2(B + 2); BElements[2] = MlasLoadFloat64x2(B + 4); BElements[3] = MlasLoadFloat64x2(B + 6); - __builtin_mma_xvf64gerpp (&acc[0], Apair, (vec_t)BElements[0]); - __builtin_mma_xvf64gerpp (&acc[1], Apair, (vec_t)BElements[1]); - __builtin_mma_xvf64gerpp (&acc[2], Apair, (vec_t)BElements[2]); - __builtin_mma_xvf64gerpp (&acc[3], Apair, (vec_t)BElements[3]); + __builtin_mma_xvf64gerpp (&acc[0], Apair, reinterpret_cast(BElements[0])); + __builtin_mma_xvf64gerpp (&acc[1], Apair, reinterpret_cast(BElements[1])); + __builtin_mma_xvf64gerpp (&acc[2], Apair, reinterpret_cast(BElements[2])); + __builtin_mma_xvf64gerpp (&acc[3], Apair, reinterpret_cast(BElements[3])); if (CountM == 8) { - __builtin_mma_xvf64gerpp (&acc[4], A2pair, (vec_t)BElements[0]); - __builtin_mma_xvf64gerpp (&acc[5], A2pair, (vec_t)BElements[1]); - __builtin_mma_xvf64gerpp (&acc[6], A2pair, (vec_t)BElements[2]); - __builtin_mma_xvf64gerpp (&acc[7], A2pair, (vec_t)BElements[3]); + __builtin_mma_xvf64gerpp (&acc[4], A2pair, reinterpret_cast(BElements[0])); + __builtin_mma_xvf64gerpp (&acc[5], A2pair, reinterpret_cast(BElements[1])); + __builtin_mma_xvf64gerpp (&acc[6], A2pair, reinterpret_cast(BElements[2])); + __builtin_mma_xvf64gerpp (&acc[7], A2pair, reinterpret_cast(BElements[3])); } } template @@ -108,10 +115,10 @@ struct MlasDgemmStoreVectorMMA { MLAS_FLOAT64X2 *rowC; if (ZeroMode) { - rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount]; + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); rowC[0] = Result[Row] * AlphaBroadcast; } else { - rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount]; + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); rowC[0] += Result[Row] * AlphaBroadcast; } } @@ -238,22 +245,22 @@ MlasDgemmMMAProcessCount( // // Store the entire output block. // - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + __builtin_mma_disassemble_acc (Result, &acc[1]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[2]); + __builtin_mma_disassemble_acc (Result, &acc[2]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[3]); + __builtin_mma_disassemble_acc (Result, &acc[3]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + __builtin_mma_disassemble_acc (Result, &acc[5]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[6]); + __builtin_mma_disassemble_acc (Result, &acc[6]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[7]); + __builtin_mma_disassemble_acc (Result, &acc[7]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); } } else { @@ -263,60 +270,60 @@ MlasDgemmMMAProcessCount( // if (CountN >= 6) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + __builtin_mma_disassemble_acc (Result, &acc[1]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[2]); + __builtin_mma_disassemble_acc (Result, &acc[2]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + __builtin_mma_disassemble_acc (Result, &acc[5]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[6]); + __builtin_mma_disassemble_acc (Result, &acc[6]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); if (CountN - 6 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[7]); } } if (CountN - 6 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[3]); } } else if (CountN >= 4) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + __builtin_mma_disassemble_acc (Result, &acc[1]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + __builtin_mma_disassemble_acc (Result, &acc[5]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[6]); } } if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[2]); } } else if (CountN >= 2) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); if (CountN - 2 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[5]); } } if (CountN - 2 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[1]); } } else { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[0]); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[4]); } } diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp b/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp index bc08af0cd7..3dfe061c72 100644 --- a/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp +++ b/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp @@ -68,15 +68,15 @@ MlasSgemmComputeBlockMMA( BElements[1] = MlasLoadFloat32x4(B + 4); BElements[2] = MlasLoadFloat32x4(B + 8); BElements[3] = MlasLoadFloat32x4(B + 12); - __builtin_mma_xvf32gerpp (&acc[0], (vec_t) ABroadcast, (vec_t )BElements[0]); - __builtin_mma_xvf32gerpp (&acc[1], (vec_t) ABroadcast, (vec_t )BElements[1]); - __builtin_mma_xvf32gerpp (&acc[2], (vec_t) ABroadcast, (vec_t )BElements[2]); - __builtin_mma_xvf32gerpp (&acc[3], (vec_t) ABroadcast, (vec_t )BElements[3]); + __builtin_mma_xvf32gerpp (&acc[0], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[0])); + __builtin_mma_xvf32gerpp (&acc[1], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[1])); + __builtin_mma_xvf32gerpp (&acc[2], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[2])); + __builtin_mma_xvf32gerpp (&acc[3], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[3])); if (CountM == 8) { - __builtin_mma_xvf32gerpp (&acc[4], (vec_t) A2Broadcast, (vec_t )BElements[0]); - __builtin_mma_xvf32gerpp (&acc[5], (vec_t) A2Broadcast, (vec_t )BElements[1]); - __builtin_mma_xvf32gerpp (&acc[6], (vec_t) A2Broadcast, (vec_t )BElements[2]); - __builtin_mma_xvf32gerpp (&acc[7], (vec_t) A2Broadcast, (vec_t )BElements[3]); + __builtin_mma_xvf32gerpp (&acc[4], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[0])); + __builtin_mma_xvf32gerpp (&acc[5], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[1])); + __builtin_mma_xvf32gerpp (&acc[6], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[2])); + __builtin_mma_xvf32gerpp (&acc[7], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[3])); } } template @@ -96,10 +96,10 @@ struct MlasSgemmStoreVectorMMA { MLAS_FLOAT32X4 *rowC; if (ZeroMode) { - rowC = (MLAS_FLOAT32X4 *) &C[Row * ldc + VectorCount]; + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); rowC[0] = Result[Row] * AlphaBroadcast; } else { - rowC = (MLAS_FLOAT32X4 *) &C[Row * ldc + VectorCount]; + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); rowC[0] += Result[Row] * AlphaBroadcast; } } @@ -218,22 +218,22 @@ MlasSgemmMMAProcessCount( // // Store the entire output block. // - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + __builtin_mma_disassemble_acc (Result, &acc[1]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[2]); + __builtin_mma_disassemble_acc (Result, &acc[2]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[3]); + __builtin_mma_disassemble_acc (Result, &acc[3]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + __builtin_mma_disassemble_acc (Result, &acc[5]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[6]); + __builtin_mma_disassemble_acc (Result, &acc[6]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[7]); + __builtin_mma_disassemble_acc (Result, &acc[7]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); } } else { @@ -243,60 +243,60 @@ MlasSgemmMMAProcessCount( // if (CountN >= 12) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + __builtin_mma_disassemble_acc (Result, &acc[1]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[2]); + __builtin_mma_disassemble_acc (Result, &acc[2]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + __builtin_mma_disassemble_acc (Result, &acc[5]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[6]); + __builtin_mma_disassemble_acc (Result, &acc[6]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[7]); } } if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[3]); } } else if (CountN >= 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + __builtin_mma_disassemble_acc (Result, &acc[1]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + __builtin_mma_disassemble_acc (Result, &acc[5]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[6]); } } if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[2]); } } else if (CountN >= 4) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + __builtin_mma_disassemble_acc (Result, &acc[0]); MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + __builtin_mma_disassemble_acc (Result, &acc[4]); MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[5]); } } if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[1]); } } else { - __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]); + __builtin_mma_disassemble_acc (Accumulators[0], &acc[0]); if (CountM == 8) { - __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]); + __builtin_mma_disassemble_acc (Accumulators[1], &acc[4]); } }