POWER: Optimize MlasTranspose functions (#11172)

This patch makes use of POWER vector intrinsics to improve performance
of MlasTranspose functions.

Co-authored-by: Rajalakshmi Srinivasaraghavan <rajis@linux.ibm.com>
This commit is contained in:
RajalakshmiSR 2022-04-12 11:51:20 -05:00 committed by GitHub
parent 833f5d5604
commit e397d8e63e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -168,6 +168,157 @@ MlasTranspose8x8Block(
vst1_u8(&Output[OutputStride * 7], vreinterpret_u8_u32(d3.val[1]));
}
#elif defined(MLAS_TARGET_POWER)
MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
const uint32_t* Input,
size_t InputStride,
uint32_t* Output,
size_t OutputStride
)
{
__vector unsigned int a0 = vec_vsx_ld(0, Input);
__vector unsigned int a1 = vec_vsx_ld(0, &Input[InputStride]);
__vector unsigned int a2 = vec_vsx_ld(0, &Input[InputStride * 2]);
__vector unsigned int a3 = vec_vsx_ld(0, &Input[InputStride * 3]);
__vector unsigned int b0 = vec_mergeh(a0, a1);
__vector unsigned int b1 = vec_mergeh(a2, a3);
__vector unsigned int b2 = vec_mergel(a0, a1);
__vector unsigned int b3 = vec_mergel(a2, a3);
__vector unsigned int c0 = vec_xxpermdi(b0, b1, 0);
__vector unsigned int c1 = vec_xxpermdi(b0, b1, 3);
__vector unsigned int c2 = vec_xxpermdi(b2, b3, 0);
__vector unsigned int c3 = vec_xxpermdi(b2, b3, 3);
vec_vsx_st(c0, 0, Output);
vec_vsx_st(c1, 0, &Output[OutputStride]);
vec_vsx_st(c2, 0, &Output[OutputStride * 2]);
vec_vsx_st(c3, 0, &Output[OutputStride * 3]);
}
MLAS_FORCEINLINE
void
MlasTranspose16x16Block(
const uint8_t* Input,
size_t InputStride,
uint8_t* Output,
size_t OutputStride
)
{
__vector unsigned char a0 = vec_vsx_ld(0, Input);
__vector unsigned char a1 = vec_vsx_ld(0, &Input[InputStride]);
__vector unsigned char a2 = vec_vsx_ld(0, &Input[InputStride * 2]);
__vector unsigned char a3 = vec_vsx_ld(0, &Input[InputStride * 3]);
__vector unsigned char a4 = vec_vsx_ld(0, &Input[InputStride * 4]);
__vector unsigned char a5 = vec_vsx_ld(0, &Input[InputStride * 5]);
__vector unsigned char a6 = vec_vsx_ld(0, &Input[InputStride * 6]);
__vector unsigned char a7 = vec_vsx_ld(0, &Input[InputStride * 7]);
__vector unsigned char a8 = vec_vsx_ld(0, &Input[InputStride * 8]);
__vector unsigned char a9 = vec_vsx_ld(0, &Input[InputStride * 9]);
__vector unsigned char a10 = vec_vsx_ld(0, &Input[InputStride * 10]);
__vector unsigned char a11 = vec_vsx_ld(0, &Input[InputStride * 11]);
__vector unsigned char a12 = vec_vsx_ld(0, &Input[InputStride * 12]);
__vector unsigned char a13 = vec_vsx_ld(0, &Input[InputStride * 13]);
__vector unsigned char a14 = vec_vsx_ld(0, &Input[InputStride * 14]);
__vector unsigned char a15 = vec_vsx_ld(0, &Input[InputStride * 15]);
__vector unsigned char b0 = vec_mergeh(a0, a1);
__vector unsigned char b1 = vec_mergeh(a2, a3);
__vector unsigned char b2 = vec_mergeh(a4, a5);
__vector unsigned char b3 = vec_mergeh(a6, a7);
__vector unsigned char b4 = vec_mergeh(a8, a9);
__vector unsigned char b5 = vec_mergeh(a10, a11);
__vector unsigned char b6 = vec_mergeh(a12, a13);
__vector unsigned char b7 = vec_mergeh(a14, a15);
__vector unsigned char c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
__vector unsigned char c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
__vector unsigned char c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
__vector unsigned char c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
__vector unsigned char d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
__vector unsigned char d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
__vector unsigned char e0 = vec_xxpermdi(d0, d1, 0);
__vector unsigned char e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[0]);
vec_vsx_st(e1, 0, &Output[OutputStride]);
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 2]);
vec_vsx_st(e1, 0, &Output[OutputStride * 3]);
c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 4]);
vec_vsx_st(e1, 0, &Output[OutputStride * 5]);
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 6]);
vec_vsx_st(e1, 0, &Output[OutputStride * 7]);
b0 = vec_mergel(a0, a1);
b1 = vec_mergel(a2, a3);
b2 = vec_mergel(a4, a5);
b3 = vec_mergel(a6, a7);
b4 = vec_mergel(a8, a9);
b5 = vec_mergel(a10, a11);
b6 = vec_mergel(a12, a13);
b7 = vec_mergel(a14, a15);
c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 8]);
vec_vsx_st(e1, 0, &Output[OutputStride * 9]);
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 10]);
vec_vsx_st(e1, 0, &Output[OutputStride * 11]);
c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 12]);
vec_vsx_st(e1, 0, &Output[OutputStride * 13]);
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
e0 = vec_xxpermdi(d0, d1, 0);
e1 = vec_xxpermdi(d0, d1, 3);
vec_vsx_st(e0, 0, &Output[OutputStride * 14]);
vec_vsx_st(e1, 0, &Output[OutputStride * 15]);
}
#endif
template<typename ElementType>
@ -191,6 +342,24 @@ MlasTranspose4xNVector(
Output[OutputStride * 3] = a3;
}
#if defined(MLAS_TARGET_POWER)
template<typename ElementType>
MLAS_FORCEINLINE
void
MlasTranspose16xNVector(
const ElementType* Input,
size_t InputStride,
ElementType* Output,
size_t OutputStride
)
{
MlasTranspose4xNVector(&Input[InputStride * 0], InputStride, &Output[OutputStride * 0], OutputStride);
MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride);
MlasTranspose4xNVector(&Input[InputStride * 8], InputStride, &Output[OutputStride * 8], OutputStride);
MlasTranspose4xNVector(&Input[InputStride * 12], InputStride, &Output[OutputStride * 12], OutputStride);
}
#endif
template<typename ElementType>
MLAS_FORCEINLINE
void
@ -251,7 +420,7 @@ Return Value:
uint32_t* d = Output;
size_t m = M;
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER)
while (m >= 4) {
@ -368,7 +537,35 @@ Return Value:
// Transpose elements from the input matrix to the output matrix 8 columns
// at a time.
//
#if defined(MLAS_TARGET_POWER)
while (n >= 16) {
const uint8_t* s = Input;
uint8_t* d = Output;
size_t m = M;
while (m >= 16) {
MlasTranspose16x16Block(s, N, d, M);
s += N * 16;
d += 16;
m -= 16;
}
while (m > 0) {
MlasTranspose16xNVector(s, 1, d, M);
s += N;
d += 1;
m -= 1;
}
Input += 16;
Output += M * 16;
n -= 16;
}
#endif
while (n >= 8) {
const uint8_t* s = Input;
@ -450,4 +647,4 @@ MlasTranspose(
reinterpret_cast<uint8_t*>(Output),
M,
N);
}
}