From 603026fb84df38a76ac04fd9741ef6d1fa960676 Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Thu, 2 Mar 2023 11:32:05 -0800
Subject: [PATCH] Transpose for 16b tensors (#14877)

### Description
Matrix transpose for 16b tensors (shorts, and half precision floats)

### Motivation and Context
Need it for fp16 operations
---

Reviewer notes (text in this area is not part of the commit message):

* mlas.h: declares a new MlasTranspose overload taking
  (const uint16_t* Input, uint16_t* Output, size_t M, size_t N).
  The overload is typed on uint16_t only; the commit message says it is
  also meant for half precision data -- presumably callers pass fp16
  buffers reinterpreted as uint16_t. TODO confirm against callers.

* transpose.cpp, MlasTranspose4x4Block (SSE2 and NEON, uint16_t):
  loads four rows of four 16-bit elements (64 bits per row), interleaves
  rows 0/2 and rows 1/3, interleaves those two results again, and stores
  four 64-bit output rows. After the second interleave each 64-bit half
  holds one input column, i.e. one output row.

* transpose.cpp, MlasTranspose (uint16_t): processes the input four
  columns at a time. Full 4x4 tiles go through MlasTranspose4x4Block
  only when MLAS_SSE2_INTRINSICS or MLAS_NEON_INTRINSICS is defined;
  otherwise, and for leftover rows, it falls back to
  MlasTranspose4xNVector(s, 1, d, M). Leftover columns (N % 4) are
  handled with MlasTranspose4xNVector(s, N, d, 1) plus a scalar loop.

* test_transpose.cpp: NOTE(review): the template arguments in this hunk
  (<uint8_t>, <uint16_t>, <uint32_t>) and the indentation of every hunk
  were reconstructed from the hunk line counts and the surrounding
  code; verify against the upstream commit before applying.

 onnxruntime/core/mlas/inc/mlas.h              |   9 +
 onnxruntime/core/mlas/lib/transpose.cpp       | 162 ++++++++++++++++++
 .../test/mlas/unittest/test_transpose.cpp     |   6 +-
 3 files changed, 175 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index b87cd0a77b..b141be78e3 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1051,6 +1051,15 @@ MlasTranspose(
     size_t N
     );
 
+void
+MLASCALL
+MlasTranspose(
+    const uint16_t* Input,
+    uint16_t* Output,
+    size_t M,
+    size_t N
+    );
+
 void
 MLASCALL
 MlasTranspose(
diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp
index c9e35ec1c3..86b0897bb9 100644
--- a/onnxruntime/core/mlas/lib/transpose.cpp
+++ b/onnxruntime/core/mlas/lib/transpose.cpp
@@ -48,6 +48,32 @@ MlasTranspose4x4Block(
     _mm_storeu_si128((__m128i*)&Output[OutputStride * 3], c3);
 }
 
+MLAS_FORCEINLINE
+void
+MlasTranspose4x4Block(
+    const uint16_t* Input,
+    size_t InputStride,
+    uint16_t* Output,
+    size_t OutputStride
+    )
+{
+    __m128i a0 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 0]);
+    __m128i a1 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 1]);
+    __m128i a2 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 2]);
+    __m128i a3 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 3]);
+
+    __m128i b0 = _mm_unpacklo_epi16(a0, a2);
+    __m128i b1 = _mm_unpacklo_epi16(a1, a3);
+
+    __m128i c0 = _mm_unpacklo_epi16(b0, b1);
+    __m128i c1 = _mm_unpackhi_epi16(b0, b1);
+
+    _mm_storel_pi((__m64*)&Output[OutputStride * 0], _mm_castsi128_ps(c0));
+    _mm_storeh_pi((__m64*)&Output[OutputStride * 1], _mm_castsi128_ps(c0));
+    _mm_storel_pi((__m64*)&Output[OutputStride * 2], _mm_castsi128_ps(c1));
+    _mm_storeh_pi((__m64*)&Output[OutputStride * 3], _mm_castsi128_ps(c1));
+}
+
 MLAS_FORCEINLINE
 void
 MlasTranspose8x8Block(
@@ -123,6 +149,32 @@ MlasTranspose4x4Block(
     vst1q_u32(&Output[OutputStride * 3], c1.val[1]);
 }
 
+MLAS_FORCEINLINE
+void
+MlasTranspose4x4Block(
+    const uint16_t* Input,
+    size_t InputStride,
+    uint16_t* Output,
+    size_t OutputStride
+    )
+{
+    uint16x4_t a0 = vld1_u16(&Input[InputStride * 0]);
+    uint16x4_t a1 = vld1_u16(&Input[InputStride * 1]);
+    uint16x4_t a2 = vld1_u16(&Input[InputStride * 2]);
+    uint16x4_t a3 = vld1_u16(&Input[InputStride * 3]);
+
+    uint16x4x2_t b0 = vzip_u16(a0, a2);
+    uint16x4x2_t b1 = vzip_u16(a1, a3);
+
+    uint16x4x2_t c0 = vzip_u16(b0.val[0], b1.val[0]);
+    uint16x4x2_t c1 = vzip_u16(b0.val[1], b1.val[1]);
+
+    vst1_u16(&Output[OutputStride * 0], c0.val[0]);
+    vst1_u16(&Output[OutputStride * 1], c0.val[1]);
+    vst1_u16(&Output[OutputStride * 2], c1.val[0]);
+    vst1_u16(&Output[OutputStride * 3], c1.val[1]);
+}
+
 MLAS_FORCEINLINE
 void
 MlasTranspose8x8Block(
@@ -498,6 +550,116 @@ MlasTranspose(
         N);
 }
 
+
+void
+MLASCALL
+MlasTranspose(
+    const uint16_t* Input,
+    uint16_t* Output,
+    size_t M,
+    size_t N
+    )
+/*++
+
+Routine Description:
+
+    This routine transposes the input matrix (M rows by N columns) to the
+    output matrix (N rows by M columns).
+
+Arguments:
+
+    Input - Supplies the input buffer.
+
+    Output - Supplies the output buffer.
+
+    M - Supplies the number of rows for the input matrix and the number of
+        columns for the output matrix.
+
+    N - Supplies the number of columns for the input matrix and the number of
+        rows for the output matrix.
+
+Return Value:
+
+    None.
+
+--*/
+{
+    size_t n = N;
+
+    //
+    // Transpose elements from the input matrix to the output matrix 4 columns
+    // at a time.
+    //
+
+    while (n >= 4) {
+
+        const uint16_t* s = Input;
+        uint16_t* d = Output;
+        size_t m = M;
+
+#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)
+
+        while (m >= 4) {
+
+            MlasTranspose4x4Block(s, N, d, M);
+
+            s += N * 4;
+            d += 4;
+            m -= 4;
+        }
+
+#endif
+
+        while (m > 0) {
+
+            MlasTranspose4xNVector(s, 1, d, M);
+
+            s += N;
+            d += 1;
+            m -= 1;
+        }
+
+        Input += 4;
+        Output += M * 4;
+        n -= 4;
+    }
+
+    //
+    // Transpose elements from the input matrix to the output matrix for the
+    // remaining columns.
+    //
+
+    while (n > 0) {
+
+        const uint16_t* s = Input;
+        uint16_t* d = Output;
+        size_t m = M;
+
+        while (m >= 4) {
+
+            MlasTranspose4xNVector(s, N, d, 1);
+
+            s += N * 4;
+            d += 4;
+            m -= 4;
+        }
+
+        while (m > 0) {
+
+            d[0] = s[0];
+
+            s += N;
+            d += 1;
+            m -= 1;
+        }
+
+        Input += 1;
+        Output += M;
+        n -= 1;
+    }
+}
+
+
 void
 MLASCALL
 MlasTranspose(
diff --git a/onnxruntime/test/mlas/unittest/test_transpose.cpp b/onnxruntime/test/mlas/unittest/test_transpose.cpp
index 6aa25e35b7..801c88556a 100644
--- a/onnxruntime/test/mlas/unittest/test_transpose.cpp
+++ b/onnxruntime/test/mlas/unittest/test_transpose.cpp
@@ -46,13 +46,15 @@ class MlasTransposeTest : public MlasTestBase {
 };
 
 template <> MlasTransposeTest<uint8_t>* MlasTestFixture<MlasTransposeTest<uint8_t>>::mlas_tester(nullptr);
+template <> MlasTransposeTest<uint16_t>* MlasTestFixture<MlasTransposeTest<uint16_t>>::mlas_tester(nullptr);
 template <> MlasTransposeTest<uint32_t>* MlasTestFixture<MlasTransposeTest<uint32_t>>::mlas_tester(nullptr);
 
 static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
   size_t count = 0;
   if (is_short_execute) {
-    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
-    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
   }
   return count;
 });