mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
POWER: Optimize MlasTranspose functions (#11172)
This patch makes use of POWER vector intrinsics to improve performance of MlasTranspose functions. Co-authored-by: Rajalakshmi Srinivasaraghavan <rajis@linux.ibm.com>
This commit is contained in:
parent
833f5d5604
commit
e397d8e63e
1 changed files with 199 additions and 2 deletions
|
|
@ -168,6 +168,157 @@ MlasTranspose8x8Block(
|
|||
vst1_u8(&Output[OutputStride * 7], vreinterpret_u8_u32(d3.val[1]));
|
||||
}
|
||||
|
||||
#elif defined(MLAS_TARGET_POWER)
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasTranspose4x4Block(
|
||||
const uint32_t* Input,
|
||||
size_t InputStride,
|
||||
uint32_t* Output,
|
||||
size_t OutputStride
|
||||
)
|
||||
{
|
||||
__vector unsigned int a0 = vec_vsx_ld(0, Input);
|
||||
__vector unsigned int a1 = vec_vsx_ld(0, &Input[InputStride]);
|
||||
__vector unsigned int a2 = vec_vsx_ld(0, &Input[InputStride * 2]);
|
||||
__vector unsigned int a3 = vec_vsx_ld(0, &Input[InputStride * 3]);
|
||||
|
||||
__vector unsigned int b0 = vec_mergeh(a0, a1);
|
||||
__vector unsigned int b1 = vec_mergeh(a2, a3);
|
||||
__vector unsigned int b2 = vec_mergel(a0, a1);
|
||||
__vector unsigned int b3 = vec_mergel(a2, a3);
|
||||
|
||||
__vector unsigned int c0 = vec_xxpermdi(b0, b1, 0);
|
||||
__vector unsigned int c1 = vec_xxpermdi(b0, b1, 3);
|
||||
__vector unsigned int c2 = vec_xxpermdi(b2, b3, 0);
|
||||
__vector unsigned int c3 = vec_xxpermdi(b2, b3, 3);
|
||||
|
||||
vec_vsx_st(c0, 0, Output);
|
||||
vec_vsx_st(c1, 0, &Output[OutputStride]);
|
||||
vec_vsx_st(c2, 0, &Output[OutputStride * 2]);
|
||||
vec_vsx_st(c3, 0, &Output[OutputStride * 3]);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasTranspose16x16Block(
|
||||
const uint8_t* Input,
|
||||
size_t InputStride,
|
||||
uint8_t* Output,
|
||||
size_t OutputStride
|
||||
)
|
||||
{
|
||||
__vector unsigned char a0 = vec_vsx_ld(0, Input);
|
||||
__vector unsigned char a1 = vec_vsx_ld(0, &Input[InputStride]);
|
||||
__vector unsigned char a2 = vec_vsx_ld(0, &Input[InputStride * 2]);
|
||||
__vector unsigned char a3 = vec_vsx_ld(0, &Input[InputStride * 3]);
|
||||
__vector unsigned char a4 = vec_vsx_ld(0, &Input[InputStride * 4]);
|
||||
__vector unsigned char a5 = vec_vsx_ld(0, &Input[InputStride * 5]);
|
||||
__vector unsigned char a6 = vec_vsx_ld(0, &Input[InputStride * 6]);
|
||||
__vector unsigned char a7 = vec_vsx_ld(0, &Input[InputStride * 7]);
|
||||
__vector unsigned char a8 = vec_vsx_ld(0, &Input[InputStride * 8]);
|
||||
__vector unsigned char a9 = vec_vsx_ld(0, &Input[InputStride * 9]);
|
||||
__vector unsigned char a10 = vec_vsx_ld(0, &Input[InputStride * 10]);
|
||||
__vector unsigned char a11 = vec_vsx_ld(0, &Input[InputStride * 11]);
|
||||
__vector unsigned char a12 = vec_vsx_ld(0, &Input[InputStride * 12]);
|
||||
__vector unsigned char a13 = vec_vsx_ld(0, &Input[InputStride * 13]);
|
||||
__vector unsigned char a14 = vec_vsx_ld(0, &Input[InputStride * 14]);
|
||||
__vector unsigned char a15 = vec_vsx_ld(0, &Input[InputStride * 15]);
|
||||
|
||||
__vector unsigned char b0 = vec_mergeh(a0, a1);
|
||||
__vector unsigned char b1 = vec_mergeh(a2, a3);
|
||||
__vector unsigned char b2 = vec_mergeh(a4, a5);
|
||||
__vector unsigned char b3 = vec_mergeh(a6, a7);
|
||||
__vector unsigned char b4 = vec_mergeh(a8, a9);
|
||||
__vector unsigned char b5 = vec_mergeh(a10, a11);
|
||||
__vector unsigned char b6 = vec_mergeh(a12, a13);
|
||||
__vector unsigned char b7 = vec_mergeh(a14, a15);
|
||||
__vector unsigned char c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
|
||||
__vector unsigned char c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
|
||||
__vector unsigned char c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
|
||||
__vector unsigned char c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
|
||||
|
||||
__vector unsigned char d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
__vector unsigned char d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
__vector unsigned char e0 = vec_xxpermdi(d0, d1, 0);
|
||||
__vector unsigned char e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[0]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride]);
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 2]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 3]);
|
||||
|
||||
c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
|
||||
c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
|
||||
c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
|
||||
c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 4]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 5]);
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 6]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 7]);
|
||||
|
||||
b0 = vec_mergel(a0, a1);
|
||||
b1 = vec_mergel(a2, a3);
|
||||
b2 = vec_mergel(a4, a5);
|
||||
b3 = vec_mergel(a6, a7);
|
||||
b4 = vec_mergel(a8, a9);
|
||||
b5 = vec_mergel(a10, a11);
|
||||
b6 = vec_mergel(a12, a13);
|
||||
b7 = vec_mergel(a14, a15);
|
||||
|
||||
c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
|
||||
c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
|
||||
c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
|
||||
c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 8]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 9]);
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 10]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 11]);
|
||||
|
||||
c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1)));
|
||||
c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3)));
|
||||
c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5)));
|
||||
c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7)));
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 12]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 13]);
|
||||
|
||||
d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1)));
|
||||
d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3)));
|
||||
e0 = vec_xxpermdi(d0, d1, 0);
|
||||
e1 = vec_xxpermdi(d0, d1, 3);
|
||||
vec_vsx_st(e0, 0, &Output[OutputStride * 14]);
|
||||
vec_vsx_st(e1, 0, &Output[OutputStride * 15]);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename ElementType>
|
||||
|
|
@ -191,6 +342,24 @@ MlasTranspose4xNVector(
|
|||
Output[OutputStride * 3] = a3;
|
||||
}
|
||||
|
||||
#if defined(MLAS_TARGET_POWER)
|
||||
template<typename ElementType>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasTranspose16xNVector(
|
||||
const ElementType* Input,
|
||||
size_t InputStride,
|
||||
ElementType* Output,
|
||||
size_t OutputStride
|
||||
)
|
||||
{
|
||||
MlasTranspose4xNVector(&Input[InputStride * 0], InputStride, &Output[OutputStride * 0], OutputStride);
|
||||
MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride);
|
||||
MlasTranspose4xNVector(&Input[InputStride * 8], InputStride, &Output[OutputStride * 8], OutputStride);
|
||||
MlasTranspose4xNVector(&Input[InputStride * 12], InputStride, &Output[OutputStride * 12], OutputStride);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename ElementType>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
|
|
@ -251,7 +420,7 @@ Return Value:
|
|||
uint32_t* d = Output;
|
||||
size_t m = M;
|
||||
|
||||
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
|
||||
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER)
|
||||
|
||||
while (m >= 4) {
|
||||
|
||||
|
|
@ -368,7 +537,35 @@ Return Value:
|
|||
// Transpose elements from the input matrix to the output matrix 8 columns
|
||||
// at a time.
|
||||
//
|
||||
#if defined(MLAS_TARGET_POWER)
|
||||
while (n >= 16) {
|
||||
|
||||
const uint8_t* s = Input;
|
||||
uint8_t* d = Output;
|
||||
size_t m = M;
|
||||
while (m >= 16) {
|
||||
|
||||
MlasTranspose16x16Block(s, N, d, M);
|
||||
|
||||
s += N * 16;
|
||||
d += 16;
|
||||
m -= 16;
|
||||
}
|
||||
|
||||
while (m > 0) {
|
||||
|
||||
MlasTranspose16xNVector(s, 1, d, M);
|
||||
|
||||
s += N;
|
||||
d += 1;
|
||||
m -= 1;
|
||||
}
|
||||
|
||||
Input += 16;
|
||||
Output += M * 16;
|
||||
n -= 16;
|
||||
}
|
||||
#endif
|
||||
while (n >= 8) {
|
||||
|
||||
const uint8_t* s = Input;
|
||||
|
|
@ -450,4 +647,4 @@ MlasTranspose(
|
|||
reinterpret_cast<uint8_t*>(Output),
|
||||
M,
|
||||
N);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue