Transpose for 16b tensors (#14877)

### Description
Matrix transpose for 16b tensors (shorts, and half precision floats)


### Motivation and Context

Need it for fp16 operations
This commit is contained in:
Chen Fu 2023-03-02 11:32:05 -08:00 committed by GitHub
parent 7cd4b334a9
commit 603026fb84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 175 additions and 2 deletions

View file

@ -1051,6 +1051,15 @@ MlasTranspose(
size_t N
);
void
MLASCALL
MlasTranspose(
const uint16_t* Input,
uint16_t* Output,
size_t M,
size_t N
);
void
MLASCALL
MlasTranspose(

View file

@ -48,6 +48,32 @@ MlasTranspose4x4Block(
_mm_storeu_si128((__m128i*)&Output[OutputStride * 3], c3);
}
MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
const uint16_t* Input,
size_t InputStride,
uint16_t* Output,
size_t OutputStride
)
{
__m128i a0 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 0]);
__m128i a1 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 1]);
__m128i a2 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 2]);
__m128i a3 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 3]);
__m128i b0 = _mm_unpacklo_epi16(a0, a2);
__m128i b1 = _mm_unpacklo_epi16(a1, a3);
__m128i c0 = _mm_unpacklo_epi16(b0, b1);
__m128i c1 = _mm_unpackhi_epi16(b0, b1);
_mm_storel_pi((__m64*)&Output[OutputStride * 0], _mm_castsi128_ps(c0));
_mm_storeh_pi((__m64*)&Output[OutputStride * 1], _mm_castsi128_ps(c0));
_mm_storel_pi((__m64*)&Output[OutputStride * 2], _mm_castsi128_ps(c1));
_mm_storeh_pi((__m64*)&Output[OutputStride * 3], _mm_castsi128_ps(c1));
}
MLAS_FORCEINLINE
void
MlasTranspose8x8Block(
@ -123,6 +149,32 @@ MlasTranspose4x4Block(
vst1q_u32(&Output[OutputStride * 3], c1.val[1]);
}
MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
const uint16_t* Input,
size_t InputStride,
uint16_t* Output,
size_t OutputStride
)
{
uint16x4_t a0 = vld1_u16(&Input[InputStride * 0]);
uint16x4_t a1 = vld1_u16(&Input[InputStride * 1]);
uint16x4_t a2 = vld1_u16(&Input[InputStride * 2]);
uint16x4_t a3 = vld1_u16(&Input[InputStride * 3]);
uint16x4x2_t b0 = vzip_u16(a0, a2);
uint16x4x2_t b1 = vzip_u16(a1, a3);
uint16x4x2_t c0 = vzip_u16(b0.val[0], b1.val[0]);
uint16x4x2_t c1 = vzip_u16(b0.val[1], b1.val[1]);
vst1_u16(&Output[OutputStride * 0], c0.val[0]);
vst1_u16(&Output[OutputStride * 1], c0.val[1]);
vst1_u16(&Output[OutputStride * 2], c1.val[0]);
vst1_u16(&Output[OutputStride * 3], c1.val[1]);
}
MLAS_FORCEINLINE
void
MlasTranspose8x8Block(
@ -498,6 +550,116 @@ MlasTranspose(
N);
}
void
MLASCALL
MlasTranspose(
const uint16_t* Input,
uint16_t* Output,
size_t M,
size_t N
)
/*++
Routine Description:
This routine transposes the input matrix (M rows by N columns) to the
output matrix (N rows by M columns).
Arguments:
Input - Supplies the input buffer.
Output - Supplies the output buffer.
M - Supplies the number of rows for the input matrix and the number of
columns for the output matrix.
N - Supplies the number of columns for the input matrix and the number of
rows for the output matrix.
Return Value:
None.
--*/
{
size_t n = N;
//
// Transpose elements from the input matrix to the output matrix 4 columns
// at a time.
//
while (n >= 4) {
const uint16_t* s = Input;
uint16_t* d = Output;
size_t m = M;
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
while (m >= 4) {
MlasTranspose4x4Block(s, N, d, M);
s += N * 4;
d += 4;
m -= 4;
}
#endif
while (m > 0) {
MlasTranspose4xNVector(s, 1, d, M);
s += N;
d += 1;
m -= 1;
}
Input += 4;
Output += M * 4;
n -= 4;
}
//
// Transpose elements from the input matrix to the output matrix for the
// remaining columns.
//
while (n > 0) {
const uint16_t* s = Input;
uint16_t* d = Output;
size_t m = M;
while (m >= 4) {
MlasTranspose4xNVector(s, N, d, 1);
s += N * 4;
d += 4;
m -= 4;
}
while (m > 0) {
d[0] = s[0];
s += N;
d += 1;
m -= 1;
}
Input += 1;
Output += M;
n -= 1;
}
}
void
MLASCALL
MlasTranspose(

View file

@ -46,13 +46,15 @@ class MlasTransposeTest : public MlasTestBase {
};
template <> MlasTransposeTest<uint32_t>* MlasTestFixture<MlasTransposeTest<uint32_t>>::mlas_tester(nullptr);
template <> MlasTransposeTest<uint16_t>* MlasTestFixture<MlasTransposeTest<uint16_t>>::mlas_tester(nullptr);
template <> MlasTransposeTest<uint8_t>* MlasTestFixture<MlasTransposeTest<uint8_t>>::mlas_tester(nullptr);
static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
size_t count = 0;
if (is_short_execute) {
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
}
return count;
});