MLAS transpose: add POWER (VSX) kernels.

This patch modifies onnxruntime/core/mlas/lib/transpose.cpp only. Inside a
new `#elif defined(MLAS_TARGET_POWER)` branch it adds two block kernels:

  * MlasTranspose4x4Block(const uint32_t*, size_t, uint32_t*, size_t)
    Loads four rows with vec_vsx_ld, interleaves row pairs with
    vec_mergeh/vec_mergel, combines the doubleword halves with
    vec_xxpermdi (selectors 0 and 3) and stores four rows with vec_vsx_st.

  * MlasTranspose16x16Block(const uint8_t*, size_t, uint8_t*, size_t)
    Loads sixteen rows, then merges at byte, halfword and word granularity
    (the latter two through reinterpret_cast to `__vector unsigned short`
    and `__vector unsigned int`), finishes each pair with vec_xxpermdi and
    stores two output rows per step, sixteen rows in total.

It also adds, under `#if defined(MLAS_TARGET_POWER)`:

  * MlasTranspose16xNVector, which forwards to MlasTranspose4xNVector four
    times, at row offsets 0, 4, 8 and 12.

  * A `while (n >= 16)` loop placed ahead of the existing `while (n >= 8)`
    loop. It calls MlasTranspose16x16Block while `m >= 16` and then
    MlasTranspose16xNVector once per remaining `m`.

The preprocessor guard in front of the existing `while (m >= 4)` loop is
extended with `|| defined(MLAS_TARGET_POWER)`, and the last hunk adds the
missing newline at the end of the file.

NOTE(review): the text this patch was recovered from had lost its line
breaks and its angle-bracket argument lists. `template<typename ElementType>`
and `reinterpret_cast<uint8_t*>` were restored from the surrounding
declarations and should be confirmed against the repository before applying.
NOTE(review): the context comment "8 columns at a time" now sits directly
above the 16-column loop; consider rewording it in a follow-up change.

diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp
index 37181ec2f3..c9e35ec1c3 100644
--- a/onnxruntime/core/mlas/lib/transpose.cpp
+++ b/onnxruntime/core/mlas/lib/transpose.cpp
@@ -168,6 +168,157 @@ MlasTranspose8x8Block(
     vst1_u8(&Output[OutputStride * 7], vreinterpret_u8_u32(d3.val[1]));
 }
 
+#elif defined(MLAS_TARGET_POWER)
+
+MLAS_FORCEINLINE
+void
+MlasTranspose4x4Block(
+    const uint32_t* Input,
+    size_t InputStride,
+    uint32_t* Output,
+    size_t OutputStride
+    )
+{
+    __vector unsigned int a0 = vec_vsx_ld(0, Input);
+    __vector unsigned int a1 = vec_vsx_ld(0, &Input[InputStride]);
+    __vector unsigned int a2 = vec_vsx_ld(0, &Input[InputStride * 2]);
+    __vector unsigned int a3 = vec_vsx_ld(0, &Input[InputStride * 3]);
+
+    __vector unsigned int b0 = vec_mergeh(a0, a1);
+    __vector unsigned int b1 = vec_mergeh(a2, a3);
+    __vector unsigned int b2 = vec_mergel(a0, a1);
+    __vector unsigned int b3 = vec_mergel(a2, a3);
+
+    __vector unsigned int c0 = vec_xxpermdi(b0, b1, 0);
+    __vector unsigned int c1 = vec_xxpermdi(b0, b1, 3);
+    __vector unsigned int c2 = vec_xxpermdi(b2, b3, 0);
+    __vector unsigned int c3 = vec_xxpermdi(b2, b3, 3);
+
+    vec_vsx_st(c0, 0, Output);
+    vec_vsx_st(c1, 0, &Output[OutputStride]);
+    vec_vsx_st(c2, 0, &Output[OutputStride * 2]);
+    vec_vsx_st(c3, 0, &Output[OutputStride * 3]);
+}
+
+MLAS_FORCEINLINE
+void
+MlasTranspose16x16Block(
+    const uint8_t* Input,
+    size_t InputStride,
+    uint8_t* Output,
+    size_t OutputStride
+    )
+{
+    __vector unsigned char a0 = vec_vsx_ld(0, Input);
+    __vector unsigned char a1 = vec_vsx_ld(0, &Input[InputStride]);
+    __vector unsigned char a2 = vec_vsx_ld(0, &Input[InputStride * 2]);
+    __vector unsigned char a3 = vec_vsx_ld(0, &Input[InputStride * 3]);
+    __vector unsigned char a4 = vec_vsx_ld(0, &Input[InputStride * 4]);
+    __vector unsigned char a5 = vec_vsx_ld(0, &Input[InputStride * 5]);
+    __vector unsigned char a6 = vec_vsx_ld(0, &Input[InputStride * 6]);
+    __vector unsigned char a7 = vec_vsx_ld(0, &Input[InputStride * 7]);
+    __vector unsigned char a8 = vec_vsx_ld(0, &Input[InputStride * 8]);
+    __vector unsigned char a9 = vec_vsx_ld(0, &Input[InputStride * 9]);
+    __vector unsigned char a10 = vec_vsx_ld(0, &Input[InputStride * 10]);
+    __vector unsigned char a11 = vec_vsx_ld(0, &Input[InputStride * 11]);
+    __vector unsigned char a12 = vec_vsx_ld(0, &Input[InputStride * 12]);
+    __vector unsigned char a13 = vec_vsx_ld(0, &Input[InputStride * 13]);
+    __vector unsigned char a14 = vec_vsx_ld(0, &Input[InputStride * 14]);
+    __vector unsigned char a15 = vec_vsx_ld(0, &Input[InputStride * 15]);
+
+    __vector unsigned char b0 = vec_mergeh(a0, a1);
+    __vector unsigned char b1 = vec_mergeh(a2, a3);
+    __vector unsigned char b2 = vec_mergeh(a4, a5);
+    __vector unsigned char b3 = vec_mergeh(a6, a7);
+    __vector unsigned char b4 = vec_mergeh(a8, a9);
+    __vector unsigned char b5 = vec_mergeh(a10, a11);
+    __vector unsigned char b6 = vec_mergeh(a12, a13);
+    __vector unsigned char b7 = vec_mergeh(a14, a15);
+    __vector unsigned char c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
+    __vector unsigned char c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
+    __vector unsigned char c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
+    __vector unsigned char c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
+
+    __vector unsigned char d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    __vector unsigned char d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    __vector unsigned char e0 = vec_xxpermdi(d0, d1, 0);
+    __vector unsigned char e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[0]);
+    vec_vsx_st(e1, 0, &Output[OutputStride]);
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 2]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 3]);
+
+    c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
+    c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
+    c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
+    c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 4]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 5]);
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 6]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 7]);
+
+    b0 = vec_mergel(a0, a1);
+    b1 = vec_mergel(a2, a3);
+    b2 = vec_mergel(a4, a5);
+    b3 = vec_mergel(a6, a7);
+    b4 = vec_mergel(a8, a9);
+    b5 = vec_mergel(a10, a11);
+    b6 = vec_mergel(a12, a13);
+    b7 = vec_mergel(a14, a15);
+
+    c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
+    c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
+    c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
+    c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 8]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 9]);
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 10]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 11]);
+
+    c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
+    c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
+    c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
+    c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 12]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 13]);
+
+    d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
+    d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
+    e0 = vec_xxpermdi(d0, d1, 0);
+    e1 = vec_xxpermdi(d0, d1, 3);
+    vec_vsx_st(e0, 0, &Output[OutputStride * 14]);
+    vec_vsx_st(e1, 0, &Output[OutputStride * 15]);
+}
 #endif
 
 template<typename ElementType>
@@ -191,6 +342,24 @@ MlasTranspose4xNVector(
     Output[OutputStride * 3] = a3;
 }
 
+#if defined(MLAS_TARGET_POWER)
+template<typename ElementType>
+MLAS_FORCEINLINE
+void
+MlasTranspose16xNVector(
+    const ElementType* Input,
+    size_t InputStride,
+    ElementType* Output,
+    size_t OutputStride
+    )
+{
+    MlasTranspose4xNVector(&Input[InputStride * 0], InputStride, &Output[OutputStride * 0], OutputStride);
+    MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride);
+    MlasTranspose4xNVector(&Input[InputStride * 8], InputStride, &Output[OutputStride * 8], OutputStride);
+    MlasTranspose4xNVector(&Input[InputStride * 12], InputStride, &Output[OutputStride * 12], OutputStride);
+}
+#endif
+
 template<typename ElementType>
 MLAS_FORCEINLINE
 void
@@ -251,7 +420,7 @@ Return Value:
         uint32_t* d = Output;
         size_t m = M;
 
-#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
+#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER)
 
         while (m >= 4) {
 
@@ -368,7 +537,35 @@ Return Value:
     // Transpose elements from the input matrix to the output matrix 8 columns
     // at a time.
     //
+#if defined(MLAS_TARGET_POWER)
+    while (n >= 16) {
 
+        const uint8_t* s = Input;
+        uint8_t* d = Output;
+        size_t m = M;
+        while (m >= 16) {
+
+            MlasTranspose16x16Block(s, N, d, M);
+
+            s += N * 16;
+            d += 16;
+            m -= 16;
+        }
+
+        while (m > 0) {
+
+            MlasTranspose16xNVector(s, 1, d, M);
+
+            s += N;
+            d += 1;
+            m -= 1;
+        }
+
+        Input += 16;
+        Output += M * 16;
+        n -= 16;
+    }
+#endif
     while (n >= 8) {
 
         const uint8_t* s = Input;
@@ -450,4 +647,4 @@ MlasTranspose(
         reinterpret_cast<uint8_t*>(Output),
         M,
         N);
-}
\ No newline at end of file
+}