diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 679a563008..9df48d93be 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -115,6 +115,8 @@ else() set(ARM TRUE) elseif(dumpmachine_output MATCHES "^aarch64.*") set(ARM64 TRUE) + elseif(dumpmachine_output MATCHES "^powerpc.*") + set(POWER TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -133,6 +135,10 @@ else() set(mlas_platform_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S ) + elseif(POWER) + set(mlas_platform_srcs + ${ONNXRUNTIME_ROOT}/core/mlas/lib/power/SgemmKernelPower.cpp + ) elseif(X86) enable_language(ASM) set(mlas_platform_srcs_sse2 diff --git a/onnxruntime/core/mlas/lib/activate.cpp b/onnxruntime/core/mlas/lib/activate.cpp index fe23c12d50..b0a2e331fc 100644 --- a/onnxruntime/core/mlas/lib/activate.cpp +++ b/onnxruntime/core/mlas/lib/activate.cpp @@ -142,6 +142,8 @@ struct MLAS_ACTIVATION_FUNCTION #elif defined(MLAS_SSE2_INTRINSICS) __m128 Selection = _mm_cmple_ps(ZeroFloat32x4, Value); return _mm_or_ps(_mm_and_ps(Value, Selection), _mm_andnot_ps(Selection, ValueTimesAlpha)); +#elif defined(MLAS_VSX_INTRINSICS) + return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value)); #else #error Unsupported architecture. #endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index e09f4f4531..93b3f12835 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -34,6 +34,13 @@ Abstract: #include #include #endif +#if defined(__VSX__) +#include +// Undefine unwanted aliases from altivec.h. +#undef vector +#undef pixel +#undef bool +#endif #endif // @@ -93,6 +100,9 @@ Abstract: #if defined(_M_ARM) || defined(__arm__) #define MLAS_TARGET_ARM #endif +#if defined(__VSX__) +#define MLAS_TARGET_POWER +#endif // // Select the threading model. @@ -142,7 +152,7 @@ Abstract: // Define the prototypes of the platform optimized routines. // -#if defined(MLAS_TARGET_AMD64_IX86) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) typedef size_t @@ -469,6 +479,8 @@ extern "C" { MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelFma3; MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelAvx512F; #endif +#elif defined(MLAS_TARGET_POWER) + MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel; #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; @@ -728,6 +740,8 @@ MlasPartitionWork( #elif defined(MLAS_TARGET_ARM64) #define MLAS_NEON_INTRINSICS #define MLAS_NEON64_INTRINSICS +#elif defined(MLAS_TARGET_POWER) +#define MLAS_VSX_INTRINSICS #elif defined(MLAS_TARGET_AMD64_IX86) #define MLAS_SSE2_INTRINSICS #if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) @@ -752,9 +766,52 @@ typedef int32x4_t MLAS_INT32X4; #elif defined(MLAS_SSE2_INTRINSICS) typedef __m128 MLAS_FLOAT32X4; typedef __m128i MLAS_INT32X4; +#elif defined(MLAS_VSX_INTRINSICS) +typedef __vector float MLAS_FLOAT32X4; +typedef __vector int MLAS_INT32X4; +typedef __vector unsigned MLAS_UINT32X4; #endif -inline +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasBroadcastInt32x4(int32_t Value) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vdupq_n_s32(Value); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_set1_epi32(Value); +#else + return MLAS_INT32X4{Value, Value, Value, Value}; +#endif +} + +MLAS_FORCEINLINE +MLAS_FLOAT32X4 +MlasBroadcastFloat32x4(float Value) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vdupq_n_f32(Value); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_set1_ps(Value); +#else + return MLAS_FLOAT32X4{Value, Value, Value, Value}; +#endif +} + +MLAS_FORCEINLINE +MLAS_FLOAT32X4 +MlasBroadcastFloat32x4(const float* Value) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vld1q_dup_f32(Value); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_load_ps1(Value); +#else + return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; +#endif +} + +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasZeroFloat32x4(void) { @@ -762,10 +819,12 @@ MlasZeroFloat32x4(void) return vdupq_n_f32(0.0f); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_setzero_ps(); +#else + return MlasBroadcastFloat32x4(0.0f); #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasLoadFloat32x4(const float* Buffer) { @@ -773,10 +832,12 @@ MlasLoadFloat32x4(const float* Buffer) return vld1q_f32(Buffer); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_loadu_ps(Buffer); +#elif defined(MLAS_VSX_INTRINSICS) + return vec_vsx_ld(0, Buffer); #endif } -inline +MLAS_FORCEINLINE void MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) { @@ -784,10 +845,12 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vst1q_f32(Buffer, Vector); #elif defined(MLAS_SSE2_INTRINSICS) _mm_storeu_ps(Buffer, Vector); +#elif defined(MLAS_VSX_INTRINSICS) + vec_vsx_st(Vector, 0, Buffer); #endif } -inline +MLAS_FORCEINLINE void MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) { @@ -795,10 +858,15 @@ MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vst1q_f32(Buffer, Vector); #elif defined(MLAS_SSE2_INTRINSICS) _mm_store_ps(Buffer, Vector); +#elif defined(MLAS_VSX_INTRINSICS) + // Workaround for bad GCC warning that these parameters are set but not used. + MLAS_UNREFERENCED_PARAMETER(Buffer); + MLAS_UNREFERENCED_PARAMETER(Vector); + vec_st(Vector, 0, Buffer); #endif } -inline +MLAS_FORCEINLINE void MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) { @@ -806,11 +874,13 @@ MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vst1_f32(Buffer, vget_low_f32(Vector)); #elif defined(MLAS_SSE2_INTRINSICS) _mm_storel_pi((__m64*)Buffer, Vector); +#elif defined(MLAS_VSX_INTRINSICS) + *((int64_t*)Buffer) = ((__vector int64_t)Vector)[0]; #endif } template -inline +MLAS_FORCEINLINE void MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) { @@ -820,11 +890,13 @@ MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) // N.B. When building with AVX instructions, compilers optimize the following // to a single vextractps instruction. _mm_store_ss(Buffer, _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); +#else + *Buffer = Vector[Lane]; #endif } template -inline +MLAS_FORCEINLINE float MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector) { @@ -832,13 +904,15 @@ MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector) return vgetq_lane_f32(Vector, Lane); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_cvtss_f32(_mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); +#else + return Vector[Lane]; #endif } #if defined(MLAS_SSE2_INTRINSICS) template<> -inline +MLAS_FORCEINLINE void MlasStoreLaneFloat32x4<0>(float* Buffer, MLAS_FLOAT32X4 Vector) { @@ -846,7 +920,7 @@ MlasStoreLaneFloat32x4<0>(float* Buffer, MLAS_FLOAT32X4 Vector) } template<> -inline +MLAS_FORCEINLINE float MlasExtractLaneFloat32x4<0>(MLAS_FLOAT32X4 Vector) { @@ -855,29 +929,7 @@ MlasExtractLaneFloat32x4<0>(MLAS_FLOAT32X4 Vector) #endif -inline -MLAS_FLOAT32X4 -MlasBroadcastFloat32x4(float Value) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vdupq_n_f32(Value); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_set1_ps(Value); -#endif -} - -inline -MLAS_FLOAT32X4 -MlasBroadcastFloat32x4(const float* Value) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vld1q_dup_f32(Value); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_load_ps1(Value); -#endif -} - -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -885,10 +937,12 @@ MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vaddq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_add_ps(Vector1, Vector2); +#else + return Vector1 + Vector2; #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -896,10 +950,12 @@ MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vsubq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_sub_ps(Vector1, Vector2); +#else + return Vector1 - Vector2; #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -907,10 +963,12 @@ MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vmulq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_mul_ps(Vector1, Vector2); +#else + return Vector1 * Vector2; #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Vector3) { @@ -920,10 +978,14 @@ MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FL return _mm_fmadd_ps(Vector1, Vector2, Vector3); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_add_ps(_mm_mul_ps(Vector1, Vector2), Vector3); +#elif defined(MLAS_VSX_INTRINSICS) + return vec_madd(Vector1, Vector2, Vector3); +#else + return Vector1 * Vector2 + Vector3; #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -937,10 +999,12 @@ MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return Vector1; #elif defined(MLAS_SSE2_INTRINSICS) return _mm_div_ps(Vector1, Vector2); +#else + return Vector1 / Vector2; #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -948,10 +1012,12 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vmaxq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_max_ps(Vector1, Vector2); +#else + return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -959,10 +1025,12 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vminq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_min_ps(Vector1, Vector2); +#else + return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -970,10 +1038,14 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vreinterpretq_f32_u32(vcgtq_f32(Vector1, Vector2)); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_cmpgt_ps(Vector1, Vector2); +#elif defined(MLAS_VSX_INTRINSICS) + return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); +#else +#error Unsupported architecture. #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -981,10 +1053,12 @@ MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2))); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_and_ps(Vector1, Vector2); +#else + return MLAS_FLOAT32X4(MLAS_INT32X4(Vector1) & MLAS_INT32X4(Vector2)); #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -992,10 +1066,12 @@ MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2))); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_or_ps(Vector1, Vector2); +#else + return MLAS_FLOAT32X4(MLAS_INT32X4(Vector1) | MLAS_INT32X4(Vector2)); #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) { @@ -1003,10 +1079,12 @@ MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) return vreinterpretq_f32_u32(vandq_u32(vmvnq_u32(vreinterpretq_u32_f32(VectorNot)), vreinterpretq_u32_f32(Vector))); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_andnot_ps(VectorNot, Vector); +#else + return MLAS_FLOAT32X4(~MLAS_INT32X4(VectorNot) & MLAS_INT32X4(Vector)); #endif } -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { @@ -1014,31 +1092,25 @@ MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2))); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_xor_ps(Vector1, Vector2); +#else + return MLAS_FLOAT32X4(MLAS_INT32X4(Vector1) ^ MLAS_INT32X4(Vector2)); #endif } // calc 2^int(N) -inline +MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) { #if defined(MLAS_NEON_INTRINSICS) - int32x4_t emm0 = vaddq_s32(vcvtq_s32_f32(Vector), vdupq_n_s32(0x7f)); + MLAS_INT32X4 emm0 = vaddq_s32(vcvtq_s32_f32(Vector), MlasBroadcastInt32x4(127)); return vreinterpretq_f32_s32(vshlq_n_s32(emm0, 23)); #elif defined(MLAS_SSE2_INTRINSICS) - __m128i emm0 = _mm_add_epi32(_mm_cvttps_epi32(Vector), _mm_set1_epi32(0x7f)); + MLAS_INT32X4 emm0 = _mm_add_epi32(_mm_cvttps_epi32(Vector), MlasBroadcastInt32x4(127)); return _mm_castsi128_ps(_mm_slli_epi32(emm0, 23)); -#endif -} - -inline -MLAS_INT32X4 -MlasBroadcastInt32x4(int32_t Value) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vdupq_n_s32(Value); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_set1_epi32(Value); +#elif defined(MLAS_VSX_INTRINSICS) + MLAS_INT32X4 emm0 = vec_cts(Vector, 0) + MlasBroadcastInt32x4(127); + return MLAS_FLOAT32X4(vec_sl(emm0, MLAS_UINT32X4(MlasBroadcastInt32x4(23)))); #endif } @@ -1054,43 +1126,7 @@ typedef __m128d MLAS_FLOAT64X2; #ifndef MLAS_FLOAT64X2_UNSUPPORTED -inline -MLAS_FLOAT64X2 -MlasZeroFloat64x2(void) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_setzero_pd(); -#endif -} - -inline -MLAS_FLOAT64X2 -MlasLoadFloat64x2(const double* Buffer) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_loadu_pd(Buffer); -#endif -} - -inline -void -MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) -{ -#if defined(MLAS_SSE2_INTRINSICS) - _mm_storeu_pd(Buffer, Vector); -#endif -} - -inline -void -MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) -{ -#if defined(MLAS_SSE2_INTRINSICS) - _mm_store_pd(Buffer, Vector); -#endif -} - -inline +MLAS_FORCEINLINE MLAS_FLOAT64X2 MlasBroadcastFloat64x2(double Value) { @@ -1099,7 +1135,43 @@ MlasBroadcastFloat64x2(double Value) #endif } -inline +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasZeroFloat64x2(void) +{ +#if defined(MLAS_SSE2_INTRINSICS) + return _mm_setzero_pd(); +#endif +} + +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasLoadFloat64x2(const double* Buffer) +{ +#if defined(MLAS_SSE2_INTRINSICS) + return _mm_loadu_pd(Buffer); +#endif +} + +MLAS_FORCEINLINE +void +MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) +{ +#if defined(MLAS_SSE2_INTRINSICS) + _mm_storeu_pd(Buffer, Vector); +#endif +} + +MLAS_FORCEINLINE +void +MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) +{ +#if defined(MLAS_SSE2_INTRINSICS) + _mm_store_pd(Buffer, Vector); +#endif +} + +MLAS_FORCEINLINE MLAS_FLOAT64X2 MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) { @@ -1114,7 +1186,7 @@ MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) // Reads a platform specific time stamp counter. // -inline +MLAS_FORCEINLINE uint64_t MlasReadTimeStampCounter(void) { diff --git a/onnxruntime/core/mlas/lib/pooling.cpp b/onnxruntime/core/mlas/lib/pooling.cpp index a0336ad527..06a89973a0 100644 --- a/onnxruntime/core/mlas/lib/pooling.cpp +++ b/onnxruntime/core/mlas/lib/pooling.cpp @@ -628,7 +628,7 @@ Return Value: break; } -#if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_VSX_INTRINSICS) MlasStoreLaneFloat32x4<0>(Output, Reduction); MlasStoreLaneFloat32x4<2>(Output + 1, Reduction); #elif defined(MLAS_SSE2_INTRINSICS) @@ -1000,7 +1000,7 @@ Return Value: break; } -#if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_VSX_INTRINSICS) MlasStoreLaneFloat32x4<0>(Output, Reduction); MlasStoreLaneFloat32x4<2>(Output + 1, Reduction); #elif defined(MLAS_SSE2_INTRINSICS) @@ -1102,6 +1102,13 @@ Return Value: float ReductionValue = _mm_cvtss_f32(Reduction); +#elif defined(MLAS_VSX_INTRINSICS) + + Reduction = PoolingType::Reduce(Reduction, MLAS_FLOAT32X4(vec_splat((__vector int64_t)Reduction, 1))); + Reduction = PoolingType::Reduce(Reduction, vec_splat(Reduction, 1)); + + float ReductionValue = Reduction[0]; + #else #error Unsupported architecture. #endif @@ -1232,7 +1239,7 @@ Return Value: //TODO: use a safeint here and make sure the result value can fit into int32_t size_t TotalChannelCount = size_t(InputShape[0]) * size_t(InputShape[1]); - + InputShape += 2; OutputShape += 2; @@ -1347,5 +1354,5 @@ Return Value: PoolKernelRoutine(&WorkBlock, 1, Input + c * InputSize, Output + c * OutputSize); }, 0); return; -#endif +#endif } diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp b/onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp new file mode 100644 index 0000000000..b56986e7a5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp @@ -0,0 +1,487 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelPower.cpp + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + +--*/ + +#include "mlasi.h" + +// +// Templates to ensure that a loop is unrolled. +// + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... Arguments + ) + { + IterationType::template Iteration(Arguments...); + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... + ) + { + // Terminate the loop. + } +}; + +template +struct MlasLoopUnroll +{ + template + MLAS_FORCEINLINE + void + operator()( + IterationArgs&&... Arguments + ) + { + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +// +// Templates used with loop unrolling to perform an action on one row of the +// output. +// + +struct MlasSgemmZeroAccumulators +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount][4] + ) + { + Accumulators[Row][0] = MlasZeroFloat32x4(); + Accumulators[Row][1] = MlasZeroFloat32x4(); + Accumulators[Row][2] = MlasZeroFloat32x4(); + Accumulators[Row][3] = MlasZeroFloat32x4(); + } +}; + +struct MlasSgemmLoadAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 AElements[RowCount], + const float* A, + size_t lda + ) + { + AElements[Row] = MlasLoadFloat32x4(A + Row * lda); + } +}; + +struct MlasSgemmBroadcastAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 ABroadcast[RowCount], + const float* A, + size_t lda + ) + { + ABroadcast[Row] = MlasBroadcastFloat32x4(A + Row * lda); + } +}; + +template +struct MlasSgemmSplatAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 AElements[RowCount], + MLAS_FLOAT32X4 ABroadcast[RowCount] + ) + { + ABroadcast[Row] = vec_splat(AElements[Row], Lane); + } +}; + +struct MlasSgemmMultiplyAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount][4], + MLAS_FLOAT32X4 ABroadcast[RowCount], + MLAS_FLOAT32X4 BElements[4] + ) + { + Accumulators[Row][0] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[0], Accumulators[Row][0]); + Accumulators[Row][1] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[1], Accumulators[Row][1]); + Accumulators[Row][2] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[2], Accumulators[Row][2]); + Accumulators[Row][3] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[3], Accumulators[Row][3]); + } +}; + +template +MLAS_FORCEINLINE +void +MlasSgemmComputeBlock( + MLAS_FLOAT32X4 Accumulators[RowCount][4], + MLAS_FLOAT32X4 ABroadcast[RowCount], + const float* B + ) +{ + MLAS_FLOAT32X4 BElements[4]; + + BElements[0] = MlasLoadFloat32x4(B); + BElements[1] = MlasLoadFloat32x4(B + 4); + BElements[2] = MlasLoadFloat32x4(B + 8); + BElements[3] = MlasLoadFloat32x4(B + 12); + + MlasLoopUnroll()(Accumulators, ABroadcast, BElements); +} + +struct MlasSgemmMultiplyAlphaRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[4], + MLAS_FLOAT32X4 AlphaBroadcast + ) + { + Accumulators[Index] = MlasMultiplyFloat32x4(Accumulators[Index], AlphaBroadcast); + } +}; + +struct MlasSgemmMultiplyAlphaAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[4], + MLAS_FLOAT32X4 AlphaBroadcast, + const float* C + ) + { + Accumulators[Index] = MlasMultiplyAddFloat32x4(Accumulators[Index], + AlphaBroadcast, MlasLoadFloat32x4(C + Index * 4)); + } +}; + +struct MlasSgemmStoreRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[4], + float* C + ) + { + MlasStoreFloat32x4(C + Index * 4, Accumulators[Index]); + } +}; + +template +struct MlasSgemmStoreVector +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount][4], + float* C, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) + { + float* c = C + Row * ldc; + + if (ZeroMode) { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); + } else { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); + } + + MlasLoopUnroll()(Accumulators[Row], c); + + // + // Shift down any unaligned elements to the bottom for further processing. + // + + if (VectorCount < 4) { + Accumulators[Row][0] = Accumulators[Row][VectorCount]; + } + } +}; + +struct MlasSgemmMultiplyAlphaTrailing +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount][4], + MLAS_FLOAT32X4 AlphaBroadcast + ) + { + Accumulators[Row][0] = MlasMultiplyFloat32x4(Accumulators[Row][0], AlphaBroadcast); + } +}; + +template +struct MlasSgemmStoreScalar +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount][4], + float* C, + size_t ldc, + bool ZeroMode + ) + { + float* c = C + Row * ldc + Lane; + float Value = MlasExtractLaneFloat32x4(Accumulators[Row][0]); + + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + +template +MLAS_FORCEINLINE +size_t +MlasSgemmProcessCount( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const float* a = A; + size_t k = CountK; + + MLAS_FLOAT32X4 Accumulators[RowCount][4]; + MLAS_FLOAT32X4 AElements[RowCount]; + MLAS_FLOAT32X4 ABroadcast[RowCount]; + + // + // Clear the block accumulators. + // + + MlasLoopUnroll()(Accumulators); + + // + // Compute the output block. + // + + while (k >= 4) { + + MlasLoopUnroll()(AElements, a, lda); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasSgemmComputeBlock(Accumulators, ABroadcast, B); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasSgemmComputeBlock(Accumulators, ABroadcast, B + 16); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasSgemmComputeBlock(Accumulators, ABroadcast, B + 32); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasSgemmComputeBlock(Accumulators, ABroadcast, B + 48); + + a += 4; + B += 16 * 4; + k -= 4; + } + + while (k > 0) { + + MlasLoopUnroll()(ABroadcast, a, lda); + MlasSgemmComputeBlock(Accumulators, ABroadcast, B); + + a += 1; + B += 16; + k -= 1; + } + + if (CountN >= 16) { + + // + // Store the entire output block. + // + + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + + } else { + + // + // Store the partial output block. + // + + if (CountN >= 12) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 8) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 4) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } + + // + // Store the remaining unaligned columns. + // + + C += (CountN & ~3); + CountN &= 3; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators, AlphaBroadcast); + + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + + if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + + if (CountN >= 3) { + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + } + + break; + } + + C += 16; + CountN -= 16; + + } while (CountN > 0); + + return RowCount; +} + +size_t +MLASCALL +MlasSgemmKernel( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see SGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + size_t RowsHandled; + + MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); + + if (CountM >= 4) { + RowsHandled = MlasSgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 4f6903f0b0..52816e22f7 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -434,6 +434,15 @@ Return Value: t1 = _mm_movehl_ps(z2, z0); t2 = _mm_movelh_ps(z1, z3); t3 = _mm_movehl_ps(z3, z1); +#elif defined(MLAS_VSX_INTRINSICS) + __vector float z0 = vec_mergeh(t0, t2); + __vector float z1 = vec_mergel(t0, t2); + __vector float z2 = vec_mergeh(t1, t3); + __vector float z3 = vec_mergel(t1, t3); + t0 = vec_mergeh(z0, z2); + t1 = vec_mergel(z0, z2); + t2 = vec_mergeh(z1, z3); + t3 = vec_mergel(z1, z3); #else #error Unsupported architecture. #endif @@ -627,10 +636,10 @@ Return Value: if ((CountY & 2) != 0) { -#if defined(MLAS_NEON_INTRINSICS) MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]); +#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_VSX_INTRINSICS) MlasStoreLaneFloat32x4<0>(&d[0], t0); MlasStoreLaneFloat32x4<0>(&d[1], t1); MlasStoreLaneFloat32x4<1>(&d[16], t0); @@ -640,9 +649,6 @@ Return Value: MlasStoreLaneFloat32x4<3>(&d[48], t0); MlasStoreLaneFloat32x4<3>(&d[49], t1); #elif defined(MLAS_SSE2_INTRINSICS) - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); - MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]); - __m128 v0 = _mm_unpacklo_ps(t0, t1); __m128 v1 = _mm_unpackhi_ps(t0, t1); _mm_storel_pi((__m64*)&d[0], v0); @@ -762,6 +768,85 @@ Return Value: } } +MLAS_FORCEINLINE +float* +MlasSgemmKernelLoop( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine steps through the rows of the input and output matrices calling + the kernel until all rows have been processed. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the number of rows from matrix A and matrix C to iterate + over. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar alpha multiplier (see SGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the next address of matrix C. + +--*/ +{ + do { + + size_t RowsHandled; + +#if defined(MLAS_TARGET_AMD64_IX86) + RowsHandled = MlasPlatform.GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); +#elif defined(MLAS_TARGET_POWER) + RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); +#else + if (ZeroMode) { + RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } else { + RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } +#endif + + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + + } while (CountM > 0); + + return C; +} + void MlasSgemmOperation( CBLAS_TRANSPOSE TransA, @@ -856,6 +941,7 @@ Return Value: // Transpose(A*B) = Transpose(B) * Transpose(A), we can apply the same 'small-M' // optimization as above, with A and B flipped. // + if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { #if defined(MLAS_TARGET_AMD64) @@ -877,7 +963,6 @@ Return Value: } - // // Compute the strides to step through slices of the input matrices. // @@ -931,9 +1016,9 @@ Return Value: // Step through each slice of matrix B along the K dimension. // - for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (beta == 0.0f); - bool ZeroMode = (k == 0 && beta == 0.0f); + for (size_t k = 0; k < K; k += CountK) { CountK = StrideK; @@ -957,39 +1042,14 @@ Return Value: float* c = C + n; - size_t RowsRemaining = M; - size_t RowsHandled; - if (TransA == CblasNoTrans) { - const float* a = A + k; - - // - // Step through the rows of matrix A. - // - - do { - -#if defined(MLAS_TARGET_AMD64_IX86) - RowsHandled = MlasPlatform.GemmFloatKernel(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha, ZeroMode); -#else - if (ZeroMode) { - RowsHandled = MlasSgemmKernelZero(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasSgemmKernelAdd(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha); - } -#endif - - c += ldc * RowsHandled; - a += lda * RowsHandled; - - RowsRemaining -= RowsHandled; - - } while (RowsRemaining > 0); + MlasSgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); } else { const float* a = A + k * lda; + size_t RowsRemaining = M; do { @@ -1003,39 +1063,21 @@ Return Value: RowsTransposed = MLAS_SGEMM_TRANSA_ROWS; } - RowsRemaining -= RowsTransposed; - MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); + RowsRemaining -= RowsTransposed; a += RowsTransposed; // // Step through the rows of the local buffer. // - const float* pa = PanelA; - - do { - -#if defined(MLAS_TARGET_AMD64_IX86) - RowsHandled = MlasPlatform.GemmFloatKernel(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); -#else - if (ZeroMode) { - RowsHandled = MlasSgemmKernelZero(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha); - } else { - RowsHandled = MlasSgemmKernelAdd(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha); - } -#endif - - c += ldc * RowsHandled; - pa += CountK * RowsHandled; - - RowsTransposed -= RowsHandled; - - } while (RowsTransposed > 0); + c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); } while (RowsRemaining > 0); } + + ZeroMode = false; } } } diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index 6f9f6b74cf..c7be47d0f2 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -1939,7 +1939,7 @@ public: Buffer[i].u = TestData[i][0].u; } - MlasActivation(&Activation, &Buffer[0].f, nullptr, _countof(Buffer), 1, 1); + MlasActivation(&Activation, &Buffer[0].f, nullptr, 1, _countof(Buffer), _countof(Buffer)); for (unsigned i = 0; i < _countof(TestData); i++) { // Sensitive to comparing positive/negative zero and NaNs.