mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Transpose for 16b tensors (#14877)
### Description
Matrix transpose for 16-bit tensors (shorts and half-precision floats).

### Motivation and Context
Needed for fp16 operations.
This commit is contained in:
parent
7cd4b334a9
commit
603026fb84
3 changed files with 175 additions and 2 deletions
|
|
@ -1051,6 +1051,15 @@ MlasTranspose(
|
|||
size_t N
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasTranspose(
|
||||
const uint16_t* Input,
|
||||
uint16_t* Output,
|
||||
size_t M,
|
||||
size_t N
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasTranspose(
|
||||
|
|
|
|||
|
|
@ -48,6 +48,32 @@ MlasTranspose4x4Block(
|
|||
_mm_storeu_si128((__m128i*)&Output[OutputStride * 3], c3);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
    const uint16_t* Input,
    size_t InputStride,
    uint16_t* Output,
    size_t OutputStride
    )
/*++

Routine Description:

    This routine transposes a 4x4 tile of 16-bit elements using SSE2
    intrinsics.

Arguments:

    Input - Supplies the address of the first element of the source tile.

    InputStride - Supplies the distance, in elements, between consecutive
        source rows.

    Output - Supplies the address of the first element of the destination
        tile.

    OutputStride - Supplies the distance, in elements, between consecutive
        destination rows.

Return Value:

    None.

--*/
{
    //
    // Fetch the four source rows. Each 64-bit load covers exactly four
    // 16-bit elements, so nothing beyond the tile is read.
    //

    const __m128i Row0 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(Input));
    const __m128i Row1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(Input + InputStride));
    const __m128i Row2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(Input + InputStride * 2));
    const __m128i Row3 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(Input + InputStride * 3));

    //
    // First interleave pairs rows 0/2 and rows 1/3; the second interleave
    // then gathers one element from every row into each group of four lanes.
    //

    const __m128i EvenRows = _mm_unpacklo_epi16(Row0, Row2);
    const __m128i OddRows = _mm_unpacklo_epi16(Row1, Row3);

    const __m128 Columns01 = _mm_castsi128_ps(_mm_unpacklo_epi16(EvenRows, OddRows));
    const __m128 Columns23 = _mm_castsi128_ps(_mm_unpackhi_epi16(EvenRows, OddRows));

    //
    // Each register now carries two transposed rows: the low 64 bits hold
    // one destination row and the high 64 bits hold the next.
    //

    _mm_storel_pi(reinterpret_cast<__m64*>(Output), Columns01);
    _mm_storeh_pi(reinterpret_cast<__m64*>(Output + OutputStride), Columns01);
    _mm_storel_pi(reinterpret_cast<__m64*>(Output + OutputStride * 2), Columns23);
    _mm_storeh_pi(reinterpret_cast<__m64*>(Output + OutputStride * 3), Columns23);
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasTranspose8x8Block(
|
||||
|
|
@ -123,6 +149,32 @@ MlasTranspose4x4Block(
|
|||
vst1q_u32(&Output[OutputStride * 3], c1.val[1]);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
    const uint16_t* Input,
    size_t InputStride,
    uint16_t* Output,
    size_t OutputStride
    )
/*++

Routine Description:

    This routine transposes a 4x4 tile of 16-bit elements using NEON
    intrinsics.

Arguments:

    Input - Supplies the address of the first element of the source tile.

    InputStride - Supplies the distance, in elements, between consecutive
        source rows.

    Output - Supplies the address of the first element of the destination
        tile.

    OutputStride - Supplies the distance, in elements, between consecutive
        destination rows.

Return Value:

    None.

--*/
{
    //
    // Fetch the four source rows, four 16-bit lanes each.
    //

    const uint16x4_t Row0 = vld1_u16(Input);
    const uint16x4_t Row1 = vld1_u16(Input + InputStride);
    const uint16x4_t Row2 = vld1_u16(Input + InputStride * 2);
    const uint16x4_t Row3 = vld1_u16(Input + InputStride * 3);

    //
    // Zip rows 0/2 and rows 1/3, then zip the matching halves of those
    // results so that every output vector holds one element from each row.
    //

    const uint16x4x2_t EvenRows = vzip_u16(Row0, Row2);
    const uint16x4x2_t OddRows = vzip_u16(Row1, Row3);

    const uint16x4x2_t Columns01 = vzip_u16(EvenRows.val[0], OddRows.val[0]);
    const uint16x4x2_t Columns23 = vzip_u16(EvenRows.val[1], OddRows.val[1]);

    //
    // Write the four transposed rows.
    //

    vst1_u16(Output, Columns01.val[0]);
    vst1_u16(Output + OutputStride, Columns01.val[1]);
    vst1_u16(Output + OutputStride * 2, Columns23.val[0]);
    vst1_u16(Output + OutputStride * 3, Columns23.val[1]);
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasTranspose8x8Block(
|
||||
|
|
@ -498,6 +550,116 @@ MlasTranspose(
|
|||
N);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasTranspose(
|
||||
const uint16_t* Input,
|
||||
uint16_t* Output,
|
||||
size_t M,
|
||||
size_t N
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine transposes the input matrix (M rows by N columns) to the
|
||||
output matrix (N rows by M columns).
|
||||
|
||||
Arguments:
|
||||
|
||||
Input - Supplies the input buffer.
|
||||
|
||||
Output - Supplies the output buffer.
|
||||
|
||||
M - Supplies the number of rows for the input matrix and the number of
|
||||
columns for the output matrix.
|
||||
|
||||
N - Supplies the number of columns for the input matrix and the number of
|
||||
rows for the output matrix.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
size_t n = N;
|
||||
|
||||
//
|
||||
// Transpose elements from the input matrix to the output matrix 4 columns
|
||||
// at a time.
|
||||
//
|
||||
|
||||
while (n >= 4) {
|
||||
|
||||
const uint16_t* s = Input;
|
||||
uint16_t* d = Output;
|
||||
size_t m = M;
|
||||
|
||||
#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
|
||||
|
||||
while (m >= 4) {
|
||||
|
||||
MlasTranspose4x4Block(s, N, d, M);
|
||||
|
||||
s += N * 4;
|
||||
d += 4;
|
||||
m -= 4;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
while (m > 0) {
|
||||
|
||||
MlasTranspose4xNVector(s, 1, d, M);
|
||||
|
||||
s += N;
|
||||
d += 1;
|
||||
m -= 1;
|
||||
}
|
||||
|
||||
Input += 4;
|
||||
Output += M * 4;
|
||||
n -= 4;
|
||||
}
|
||||
|
||||
//
|
||||
// Transpose elements from the input matrix to the output matrix for the
|
||||
// remaining columns.
|
||||
//
|
||||
|
||||
while (n > 0) {
|
||||
|
||||
const uint16_t* s = Input;
|
||||
uint16_t* d = Output;
|
||||
size_t m = M;
|
||||
|
||||
while (m >= 4) {
|
||||
|
||||
MlasTranspose4xNVector(s, N, d, 1);
|
||||
|
||||
s += N * 4;
|
||||
d += 4;
|
||||
m -= 4;
|
||||
}
|
||||
|
||||
while (m > 0) {
|
||||
|
||||
d[0] = s[0];
|
||||
|
||||
s += N;
|
||||
d += 1;
|
||||
m -= 1;
|
||||
}
|
||||
|
||||
Input += 1;
|
||||
Output += M;
|
||||
n -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasTranspose(
|
||||
|
|
|
|||
|
|
@ -46,13 +46,15 @@ class MlasTransposeTest : public MlasTestBase {
|
|||
};
|
||||
|
||||
template <> MlasTransposeTest<uint32_t>* MlasTestFixture<MlasTransposeTest<uint32_t>>::mlas_tester(nullptr);
|
||||
template <> MlasTransposeTest<uint16_t>* MlasTestFixture<MlasTransposeTest<uint16_t>>::mlas_tester(nullptr);
|
||||
template <> MlasTransposeTest<uint8_t>* MlasTestFixture<MlasTransposeTest<uint8_t>>::mlas_tester(nullptr);
|
||||
|
||||
static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
|
||||
size_t count = 0;
|
||||
if (is_short_execute) {
|
||||
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
|
||||
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
|
||||
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
|
||||
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t>>::RegisterShortExecute();
|
||||
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
|
||||
}
|
||||
return count;
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in a new issue