mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Port MLAS to Power architecture (#3703)
Updates to MLAS to support building for the Power architecture.
This commit is contained in:
parent
e22d97ba56
commit
bf1caba2b2
7 changed files with 774 additions and 158 deletions
|
|
@ -115,6 +115,8 @@ else()
|
|||
set(ARM TRUE)
|
||||
elseif(dumpmachine_output MATCHES "^aarch64.*")
|
||||
set(ARM64 TRUE)
|
||||
elseif(dumpmachine_output MATCHES "^powerpc.*")
|
||||
set(POWER TRUE)
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
|
||||
set(X86 TRUE)
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
|
||||
|
|
@ -133,6 +135,10 @@ else()
|
|||
set(mlas_platform_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/SgemmKernelNeon.S
|
||||
)
|
||||
elseif(POWER)
|
||||
set(mlas_platform_srcs
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/power/SgemmKernelPower.cpp
|
||||
)
|
||||
elseif(X86)
|
||||
enable_language(ASM)
|
||||
set(mlas_platform_srcs_sse2
|
||||
|
|
|
|||
|
|
@ -142,6 +142,8 @@ struct MLAS_ACTIVATION_FUNCTION<MlasLeakyReluActivation>
|
|||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
__m128 Selection = _mm_cmple_ps(ZeroFloat32x4, Value);
|
||||
return _mm_or_ps(_mm_and_ps(Value, Selection), _mm_andnot_ps(Selection, ValueTimesAlpha));
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value));
|
||||
#else
|
||||
#error Unsupported architecture.
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -34,6 +34,13 @@ Abstract:
|
|||
#include <cpuid.h>
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
#if defined(__VSX__)
|
||||
#include <altivec.h>
|
||||
// Undefine unwanted aliases from altivec.h.
|
||||
#undef vector
|
||||
#undef pixel
|
||||
#undef bool
|
||||
#endif
|
||||
#endif
|
||||
|
||||
//
|
||||
|
|
@ -93,6 +100,9 @@ Abstract:
|
|||
#if defined(_M_ARM) || defined(__arm__)
|
||||
#define MLAS_TARGET_ARM
|
||||
#endif
|
||||
#if defined(__VSX__)
|
||||
#define MLAS_TARGET_POWER
|
||||
#endif
|
||||
|
||||
//
|
||||
// Select the threading model.
|
||||
|
|
@ -142,7 +152,7 @@ Abstract:
|
|||
// Define the prototypes of the platform optimized routines.
|
||||
//
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER)
|
||||
|
||||
typedef
|
||||
size_t
|
||||
|
|
@ -469,6 +479,8 @@ extern "C" {
|
|||
MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelFma3;
|
||||
MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelAvx512F;
|
||||
#endif
|
||||
#elif defined(MLAS_TARGET_POWER)
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel;
|
||||
#else
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero;
|
||||
MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd;
|
||||
|
|
@ -728,6 +740,8 @@ MlasPartitionWork(
|
|||
#elif defined(MLAS_TARGET_ARM64)
|
||||
#define MLAS_NEON_INTRINSICS
|
||||
#define MLAS_NEON64_INTRINSICS
|
||||
#elif defined(MLAS_TARGET_POWER)
|
||||
#define MLAS_VSX_INTRINSICS
|
||||
#elif defined(MLAS_TARGET_AMD64_IX86)
|
||||
#define MLAS_SSE2_INTRINSICS
|
||||
#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
|
||||
|
|
@ -752,9 +766,52 @@ typedef int32x4_t MLAS_INT32X4;
|
|||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
typedef __m128 MLAS_FLOAT32X4;
|
||||
typedef __m128i MLAS_INT32X4;
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
typedef __vector float MLAS_FLOAT32X4;
|
||||
typedef __vector int MLAS_INT32X4;
|
||||
typedef __vector unsigned MLAS_UINT32X4;
|
||||
#endif
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_INT32X4
|
||||
MlasBroadcastInt32x4(int32_t Value)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vdupq_n_s32(Value);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_set1_epi32(Value);
|
||||
#else
|
||||
return MLAS_INT32X4{Value, Value, Value, Value};
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasBroadcastFloat32x4(float Value)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vdupq_n_f32(Value);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_set1_ps(Value);
|
||||
#else
|
||||
return MLAS_FLOAT32X4{Value, Value, Value, Value};
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasBroadcastFloat32x4(const float* Value)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vld1q_dup_f32(Value);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_load_ps1(Value);
|
||||
#else
|
||||
return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value};
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasZeroFloat32x4(void)
|
||||
{
|
||||
|
|
@ -762,10 +819,12 @@ MlasZeroFloat32x4(void)
|
|||
return vdupq_n_f32(0.0f);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_setzero_ps();
|
||||
#else
|
||||
return MlasBroadcastFloat32x4(0.0f);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasLoadFloat32x4(const float* Buffer)
|
||||
{
|
||||
|
|
@ -773,10 +832,12 @@ MlasLoadFloat32x4(const float* Buffer)
|
|||
return vld1q_f32(Buffer);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_loadu_ps(Buffer);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
return vec_vsx_ld(0, Buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -784,10 +845,12 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
|||
vst1q_f32(Buffer, Vector);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_storeu_ps(Buffer, Vector);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
vec_vsx_st(Vector, 0, Buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -795,10 +858,15 @@ MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
|||
vst1q_f32(Buffer, Vector);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_store_ps(Buffer, Vector);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
// Workaround for bad GCC warning that these parameters are set but not used.
|
||||
MLAS_UNREFERENCED_PARAMETER(Buffer);
|
||||
MLAS_UNREFERENCED_PARAMETER(Vector);
|
||||
vec_st(Vector, 0, Buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -806,11 +874,13 @@ MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
|||
vst1_f32(Buffer, vget_low_f32(Vector));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_storel_pi((__m64*)Buffer, Vector);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
*((int64_t*)Buffer) = ((__vector int64_t)Vector)[0];
|
||||
#endif
|
||||
}
|
||||
|
||||
template<unsigned Lane>
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -820,11 +890,13 @@ MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
|
|||
// N.B. When building with AVX instructions, compilers optimize the following
|
||||
// to a single vextractps instruction.
|
||||
_mm_store_ss(Buffer, _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane)));
|
||||
#else
|
||||
*Buffer = Vector[Lane];
|
||||
#endif
|
||||
}
|
||||
|
||||
template<unsigned Lane>
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
float
|
||||
MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -832,13 +904,15 @@ MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector)
|
|||
return vgetq_lane_f32(Vector, Lane);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_cvtss_f32(_mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane)));
|
||||
#else
|
||||
return Vector[Lane];
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
|
||||
template<>
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreLaneFloat32x4<0>(float* Buffer, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -846,7 +920,7 @@ MlasStoreLaneFloat32x4<0>(float* Buffer, MLAS_FLOAT32X4 Vector)
|
|||
}
|
||||
|
||||
template<>
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
float
|
||||
MlasExtractLaneFloat32x4<0>(MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -855,29 +929,7 @@ MlasExtractLaneFloat32x4<0>(MLAS_FLOAT32X4 Vector)
|
|||
|
||||
#endif
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasBroadcastFloat32x4(float Value)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vdupq_n_f32(Value);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_set1_ps(Value);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasBroadcastFloat32x4(const float* Value)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vld1q_dup_f32(Value);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_load_ps1(Value);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -885,10 +937,12 @@ MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vaddq_f32(Vector1, Vector2);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_add_ps(Vector1, Vector2);
|
||||
#else
|
||||
return Vector1 + Vector2;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -896,10 +950,12 @@ MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vsubq_f32(Vector1, Vector2);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_sub_ps(Vector1, Vector2);
|
||||
#else
|
||||
return Vector1 - Vector2;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -907,10 +963,12 @@ MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vmulq_f32(Vector1, Vector2);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_mul_ps(Vector1, Vector2);
|
||||
#else
|
||||
return Vector1 * Vector2;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Vector3)
|
||||
{
|
||||
|
|
@ -920,10 +978,14 @@ MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FL
|
|||
return _mm_fmadd_ps(Vector1, Vector2, Vector3);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_add_ps(_mm_mul_ps(Vector1, Vector2), Vector3);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
return vec_madd(Vector1, Vector2, Vector3);
|
||||
#else
|
||||
return Vector1 * Vector2 + Vector3;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -937,10 +999,12 @@ MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return Vector1;
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_div_ps(Vector1, Vector2);
|
||||
#else
|
||||
return Vector1 / Vector2;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -948,10 +1012,12 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vmaxq_f32(Vector1, Vector2);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_max_ps(Vector1, Vector2);
|
||||
#else
|
||||
return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -959,10 +1025,12 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vminq_f32(Vector1, Vector2);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_min_ps(Vector1, Vector2);
|
||||
#else
|
||||
return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -970,10 +1038,14 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vreinterpretq_f32_u32(vcgtq_f32(Vector1, Vector2));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_cmpgt_ps(Vector1, Vector2);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2));
|
||||
#else
|
||||
#error Unsupported architecture.
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -981,10 +1053,12 @@ MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_and_ps(Vector1, Vector2);
|
||||
#else
|
||||
return MLAS_FLOAT32X4(MLAS_INT32X4(Vector1) & MLAS_INT32X4(Vector2));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -992,10 +1066,12 @@ MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_or_ps(Vector1, Vector2);
|
||||
#else
|
||||
return MLAS_FLOAT32X4(MLAS_INT32X4(Vector1) | MLAS_INT32X4(Vector2));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
|
|
@ -1003,10 +1079,12 @@ MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector)
|
|||
return vreinterpretq_f32_u32(vandq_u32(vmvnq_u32(vreinterpretq_u32_f32(VectorNot)), vreinterpretq_u32_f32(Vector)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_andnot_ps(VectorNot, Vector);
|
||||
#else
|
||||
return MLAS_FLOAT32X4(~MLAS_INT32X4(VectorNot) & MLAS_INT32X4(Vector));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
|
|
@ -1014,31 +1092,25 @@ MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_xor_ps(Vector1, Vector2);
|
||||
#else
|
||||
return MLAS_FLOAT32X4(MLAS_INT32X4(Vector1) ^ MLAS_INT32X4(Vector2));
|
||||
#endif
|
||||
}
|
||||
|
||||
// calc 2^int(N)
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT32X4
|
||||
MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
int32x4_t emm0 = vaddq_s32(vcvtq_s32_f32(Vector), vdupq_n_s32(0x7f));
|
||||
MLAS_INT32X4 emm0 = vaddq_s32(vcvtq_s32_f32(Vector), MlasBroadcastInt32x4(127));
|
||||
return vreinterpretq_f32_s32(vshlq_n_s32(emm0, 23));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
__m128i emm0 = _mm_add_epi32(_mm_cvttps_epi32(Vector), _mm_set1_epi32(0x7f));
|
||||
MLAS_INT32X4 emm0 = _mm_add_epi32(_mm_cvttps_epi32(Vector), MlasBroadcastInt32x4(127));
|
||||
return _mm_castsi128_ps(_mm_slli_epi32(emm0, 23));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_INT32X4
|
||||
MlasBroadcastInt32x4(int32_t Value)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vdupq_n_s32(Value);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_set1_epi32(Value);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
MLAS_INT32X4 emm0 = vec_cts(Vector, 0) + MlasBroadcastInt32x4(127);
|
||||
return MLAS_FLOAT32X4(vec_sl(emm0, MLAS_UINT32X4(MlasBroadcastInt32x4(23))));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -1054,43 +1126,7 @@ typedef __m128d MLAS_FLOAT64X2;
|
|||
|
||||
#ifndef MLAS_FLOAT64X2_UNSUPPORTED
|
||||
|
||||
inline
|
||||
MLAS_FLOAT64X2
|
||||
MlasZeroFloat64x2(void)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_setzero_pd();
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT64X2
|
||||
MlasLoadFloat64x2(const double* Buffer)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_loadu_pd(Buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
void
|
||||
MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_storeu_pd(Buffer, Vector);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
void
|
||||
MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_store_pd(Buffer, Vector);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT64X2
|
||||
MlasBroadcastFloat64x2(double Value)
|
||||
{
|
||||
|
|
@ -1099,7 +1135,43 @@ MlasBroadcastFloat64x2(double Value)
|
|||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT64X2
|
||||
MlasZeroFloat64x2(void)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_setzero_pd();
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT64X2
|
||||
MlasLoadFloat64x2(const double* Buffer)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_loadu_pd(Buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_storeu_pd(Buffer, Vector);
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
|
||||
{
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
_mm_store_pd(Buffer, Vector);
|
||||
#endif
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
MLAS_FLOAT64X2
|
||||
MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2)
|
||||
{
|
||||
|
|
@ -1114,7 +1186,7 @@ MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2)
|
|||
// Reads a platform specific time stamp counter.
|
||||
//
|
||||
|
||||
inline
|
||||
MLAS_FORCEINLINE
|
||||
uint64_t
|
||||
MlasReadTimeStampCounter(void)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -628,7 +628,7 @@ Return Value:
|
|||
break;
|
||||
}
|
||||
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_VSX_INTRINSICS)
|
||||
MlasStoreLaneFloat32x4<0>(Output, Reduction);
|
||||
MlasStoreLaneFloat32x4<2>(Output + 1, Reduction);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
|
|
@ -1000,7 +1000,7 @@ Return Value:
|
|||
break;
|
||||
}
|
||||
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_VSX_INTRINSICS)
|
||||
MlasStoreLaneFloat32x4<0>(Output, Reduction);
|
||||
MlasStoreLaneFloat32x4<2>(Output + 1, Reduction);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
|
|
@ -1102,6 +1102,13 @@ Return Value:
|
|||
|
||||
float ReductionValue = _mm_cvtss_f32(Reduction);
|
||||
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
|
||||
Reduction = PoolingType::Reduce(Reduction, MLAS_FLOAT32X4(vec_splat((__vector int64_t)Reduction, 1)));
|
||||
Reduction = PoolingType::Reduce(Reduction, vec_splat(Reduction, 1));
|
||||
|
||||
float ReductionValue = Reduction[0];
|
||||
|
||||
#else
|
||||
#error Unsupported architecture.
|
||||
#endif
|
||||
|
|
@ -1232,7 +1239,7 @@ Return Value:
|
|||
|
||||
//TODO: use a safeint here and make sure the result value can fit into int32_t
|
||||
size_t TotalChannelCount = size_t(InputShape[0]) * size_t(InputShape[1]);
|
||||
|
||||
|
||||
|
||||
InputShape += 2;
|
||||
OutputShape += 2;
|
||||
|
|
@ -1347,5 +1354,5 @@ Return Value:
|
|||
PoolKernelRoutine(&WorkBlock, 1, Input + c * InputSize, Output + c * OutputSize);
|
||||
}, 0);
|
||||
return;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
487
onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp
Normal file
487
onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp
Normal file
|
|
@ -0,0 +1,487 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
SgemmKernelPower.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the kernels for the single precision matrix/matrix
|
||||
multiply operation (SGEMM).
|
||||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
//
|
||||
// Templates to ensure that a loop is unrolled.
|
||||
//
|
||||
|
||||
//
// One step of a compile-time unrolled loop. Step() invokes
// IterationType::Iteration<Count, Index> with the supplied arguments and then
// recurses into the step for Index + 1. The recursion stops at the
// <Count, Count> specialization below.
//
template<size_t Count, size_t Index>
|
||||
struct MlasLoopUnrollStep
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
// The arguments are passed on by name (not forwarded) because the same
// pack is used again by the next step.
|
||||
IterationType::template Iteration<Count, Index>(Arguments...);
|
||||
MlasLoopUnrollStep<Count, Index + 1>::template Step<IterationType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Terminal case of the unrolled loop (Index == Count): accepts the argument
// pack and does nothing.
//
template<size_t Count>
|
||||
struct MlasLoopUnrollStep<Count, Count>
|
||||
{
|
||||
template<typename IterationType, typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Step(
|
||||
IterationArgs&&...
|
||||
)
|
||||
{
|
||||
// Terminate the loop.
|
||||
}
|
||||
};
|
||||
|
||||
//
// Entry point of the loop unroller. Calling an instance as
// MlasLoopUnroll<Count, IteratorType>()(args...) starts the step chain at
// index 0, so IteratorType::Iteration<Count, I>(args...) is invoked for
// I = 0 .. Count - 1 in increasing order.
//
template<size_t Count, typename IteratorType>
|
||||
struct MlasLoopUnroll
|
||||
{
|
||||
template<typename... IterationArgs>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
operator()(
|
||||
IterationArgs&&... Arguments
|
||||
)
|
||||
{
|
||||
MlasLoopUnrollStep<Count, 0>::template Step<IteratorType>(Arguments...);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Templates used with loop unrolling to perform an action on one row of the
|
||||
// output.
|
||||
//
|
||||
|
||||
//
// Unroll iteration: sets the four accumulator vectors of output row Row to
// zero.
//
struct MlasSgemmZeroAccumulators
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasZeroFloat32x4();
|
||||
Accumulators[Row][1] = MlasZeroFloat32x4();
|
||||
Accumulators[Row][2] = MlasZeroFloat32x4();
|
||||
Accumulators[Row][3] = MlasZeroFloat32x4();
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: loads one vector (four consecutive floats) of matrix A
// for row Row, starting at A + Row * lda, into AElements[Row].
//
struct MlasSgemmLoadAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 AElements[RowCount],
|
||||
const float* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
AElements[Row] = MlasLoadFloat32x4(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: reads the single float at A + Row * lda and replicates it
// into every lane of ABroadcast[Row].
//
struct MlasSgemmBroadcastAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount],
|
||||
const float* A,
|
||||
size_t lda
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = MlasBroadcastFloat32x4(A + Row * lda);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: replicates lane Lane of AElements[Row] into every lane of
// ABroadcast[Row] using the vec_splat intrinsic.
//
template<unsigned Lane>
|
||||
struct MlasSgemmSplatAElements
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 AElements[RowCount],
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount]
|
||||
)
|
||||
{
|
||||
ABroadcast[Row] = vec_splat(AElements[Row], Lane);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: for output row Row, accumulates
// ABroadcast[Row] * BElements[i] into Accumulators[Row][i] for each of the
// four B vectors.
//
struct MlasSgemmMultiplyAddRow
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount],
|
||||
MLAS_FLOAT32X4 BElements[4]
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[0], Accumulators[Row][0]);
|
||||
Accumulators[Row][1] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[1], Accumulators[Row][1]);
|
||||
Accumulators[Row][2] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[2], Accumulators[Row][2]);
|
||||
Accumulators[Row][3] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[3], Accumulators[Row][3]);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Performs one step along K for a block of RowCount rows by 16 columns:
// loads 16 consecutive floats from B as four vectors and, for every row,
// accumulates ABroadcast[Row] * BElements[i] into Accumulators[Row][i].
//
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasSgemmComputeBlock(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount],
|
||||
const float* B
|
||||
)
|
||||
{
|
||||
MLAS_FLOAT32X4 BElements[4];
|
||||
|
||||
BElements[0] = MlasLoadFloat32x4(B);
|
||||
BElements[1] = MlasLoadFloat32x4(B + 4);
|
||||
BElements[2] = MlasLoadFloat32x4(B + 8);
|
||||
BElements[3] = MlasLoadFloat32x4(B + 12);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmMultiplyAddRow>()(Accumulators, ABroadcast, BElements);
|
||||
}
|
||||
|
||||
//
// Unroll iteration over the vectors of a single row: scales
// Accumulators[Index] by AlphaBroadcast.
//
struct MlasSgemmMultiplyAlphaRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[4],
|
||||
MLAS_FLOAT32X4 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MlasMultiplyFloat32x4(Accumulators[Index], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration over the vectors of a single row: replaces
// Accumulators[Index] with Accumulators[Index] * AlphaBroadcast plus the four
// floats currently stored at C + Index * 4.
//
struct MlasSgemmMultiplyAlphaAddRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[4],
|
||||
MLAS_FLOAT32X4 AlphaBroadcast,
|
||||
const float* C
|
||||
)
|
||||
{
|
||||
Accumulators[Index] = MlasMultiplyAddFloat32x4(Accumulators[Index],
|
||||
AlphaBroadcast, MlasLoadFloat32x4(C + Index * 4));
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration over the vectors of a single row: stores
// Accumulators[Index] to the four floats at C + Index * 4.
//
struct MlasSgemmStoreRow
|
||||
{
|
||||
template<size_t Count, size_t Index>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[4],
|
||||
float* C
|
||||
)
|
||||
{
|
||||
MlasStoreFloat32x4(C + Index * 4, Accumulators[Index]);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: writes the first VectorCount vectors (VectorCount * 4
// columns) of output row Row to C + Row * ldc.
//
// In ZeroMode the accumulators are only scaled by alpha; otherwise the
// existing contents of C are added to the scaled accumulators before the
// store.
//
template<size_t VectorCount>
|
||||
struct MlasSgemmStoreVector
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_FLOAT32X4 AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
float* c = C + Row * ldc;
|
||||
|
||||
if (ZeroMode) {
|
||||
MlasLoopUnroll<VectorCount, MlasSgemmMultiplyAlphaRow>()(Accumulators[Row], AlphaBroadcast);
|
||||
} else {
|
||||
MlasLoopUnroll<VectorCount, MlasSgemmMultiplyAlphaAddRow>()(Accumulators[Row], AlphaBroadcast, c);
|
||||
}
|
||||
|
||||
MlasLoopUnroll<VectorCount, MlasSgemmStoreRow>()(Accumulators[Row], c);
|
||||
|
||||
//
|
||||
// Shift down any unaligned elements to the bottom for further processing.
|
||||
//
|
||||
|
||||
// The vector moved into slot 0 has not been scaled by alpha yet; only the
// first VectorCount vectors were scaled above. The guard also keeps the
// VectorCount == 4 instantiation from indexing past the four slots.
if (VectorCount < 4) {
|
||||
Accumulators[Row][0] = Accumulators[Row][VectorCount];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: scales the first accumulator vector of row Row by
// AlphaBroadcast. Used for the trailing (fewer than four) output columns.
//
struct MlasSgemmMultiplyAlphaTrailing
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
MLAS_FLOAT32X4 AlphaBroadcast
|
||||
)
|
||||
{
|
||||
Accumulators[Row][0] = MlasMultiplyFloat32x4(Accumulators[Row][0], AlphaBroadcast);
|
||||
}
|
||||
};
|
||||
|
||||
//
// Unroll iteration: writes lane Lane of the first accumulator vector of row
// Row to the single element C[Row * ldc + Lane]. Unless ZeroMode is set, the
// value already stored in that element is added first.
//
template<unsigned Lane>
|
||||
struct MlasSgemmStoreScalar
|
||||
{
|
||||
template<size_t RowCount, size_t Row>
|
||||
MLAS_FORCEINLINE
|
||||
static
|
||||
void
|
||||
Iteration(
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4],
|
||||
float* C,
|
||||
size_t ldc,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
float* c = C + Row * ldc + Lane;
|
||||
float Value = MlasExtractLaneFloat32x4<Lane>(Accumulators[Row][0]);
|
||||
|
||||
if (!ZeroMode) {
|
||||
Value += *c;
|
||||
}
|
||||
|
||||
*c = Value;
|
||||
}
|
||||
};
|
||||
|
||||
//
// Computes RowCount rows of the output, walking the columns in blocks of 16.
//
// For each block the accumulators are cleared, CountK steps of A * B are
// accumulated, and the result is scaled by alpha and written to C (added to
// the existing contents of C unless ZeroMode is set). A final block of fewer
// than 16 columns is stored as whole vectors followed by up to three scalar
// stores.
//
// The A pointer is reset at the start of every column block, while B is only
// ever advanced (16 floats per K step), so each block continues reading B
// where the previous block stopped.
//
// Returns RowCount.
//
template<size_t RowCount>
|
||||
MLAS_FORCEINLINE
|
||||
size_t
|
||||
MlasSgemmProcessCount(
|
||||
const float* A,
|
||||
const float* B,
|
||||
float* C,
|
||||
size_t CountK,
|
||||
size_t CountN,
|
||||
size_t lda,
|
||||
size_t ldc,
|
||||
MLAS_FLOAT32X4 AlphaBroadcast,
|
||||
bool ZeroMode
|
||||
)
|
||||
{
|
||||
// NOTE(review): the do/while body runs once even when CountN is zero;
// presumably callers always pass CountN > 0 -- confirm.
do {
|
||||
|
||||
const float* a = A;
|
||||
size_t k = CountK;
|
||||
|
||||
MLAS_FLOAT32X4 Accumulators[RowCount][4];
|
||||
MLAS_FLOAT32X4 AElements[RowCount];
|
||||
MLAS_FLOAT32X4 ABroadcast[RowCount];
|
||||
|
||||
//
|
||||
// Clear the block accumulators.
|
||||
//
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmZeroAccumulators>()(Accumulators);
|
||||
|
||||
//
|
||||
// Compute the output block.
|
||||
//
|
||||
|
||||
// Main loop: four K steps per pass. One vector of A is loaded per row and
// each of its four lanes is splatted in turn against the matching 16 floats
// of B.
while (k >= 4) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmLoadAElements>()(AElements, a, lda);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<0>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<1>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 16);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<2>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 32);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmSplatAElements<3>>()(AElements, ABroadcast);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B + 48);
|
||||
|
||||
a += 4;
|
||||
B += 16 * 4;
|
||||
k -= 4;
|
||||
}
|
||||
|
||||
// Remainder loop: one K step per pass, broadcasting a single element of A
// per row.
while (k > 0) {
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmBroadcastAElements>()(ABroadcast, a, lda);
|
||||
MlasSgemmComputeBlock<RowCount>(Accumulators, ABroadcast, B);
|
||||
|
||||
a += 1;
|
||||
B += 16;
|
||||
k -= 1;
|
||||
}
|
||||
|
||||
if (CountN >= 16) {
|
||||
|
||||
//
|
||||
// Store the entire output block.
|
||||
//
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<4>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
|
||||
} else {
|
||||
|
||||
//
|
||||
// Store the partial output block.
|
||||
//
|
||||
|
||||
// Store as many whole vectors as fit. Each StoreVector<N> variant leaves the
// next, still unscaled, vector in slot 0 of every row.
if (CountN >= 12) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<3>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountN >= 8) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<2>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountN >= 4) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreVector<1>>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
|
||||
//
|
||||
// Store the remaining unaligned columns.
|
||||
//
|
||||
|
||||
// Skip past the columns written as whole vectors; CountN becomes the number
// of leftover columns (0 to 3).
C += (CountN & ~3);
|
||||
CountN &= 3;
|
||||
|
||||
if (CountN > 0) {
|
||||
|
||||
// Scale slot 0 by alpha, then write its low lanes one element at a time.
MlasLoopUnroll<RowCount, MlasSgemmMultiplyAlphaTrailing>()(Accumulators, AlphaBroadcast);
|
||||
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<0>>()(Accumulators, C, ldc, ZeroMode);
|
||||
|
||||
if (CountN >= 2) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<1>>()(Accumulators, C, ldc, ZeroMode);
|
||||
}
|
||||
|
||||
if (CountN >= 3) {
|
||||
MlasLoopUnroll<RowCount, MlasSgemmStoreScalar<2>>()(Accumulators, C, ldc, ZeroMode);
|
||||
}
|
||||
}
|
||||
|
||||
// A partial block is always the last block of the row set.
break;
|
||||
}
|
||||
|
||||
C += 16;
|
||||
CountN -= 16;
|
||||
|
||||
} while (CountN > 0);
|
||||
|
||||
return RowCount;
|
||||
}
|
||||
|
||||
size_t
|
||||
MLASCALL
|
||||
MlasSgemmKernel(
|
||||
const float* A,
|
||||
const float* B,
|
||||
float* C,
|
||||
size_t CountK,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t lda,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
bool ZeroMode
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine is an inner kernel to compute matrix multiplication for a
|
||||
set of rows.
|
||||
|
||||
Arguments:
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B. The matrix data has been packed using
|
||||
MlasSgemmCopyPackB or MlasSgemmTransposePackB.
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
CountK - Supplies the number of columns from matrix A and the number of rows
|
||||
from matrix B to iterate over.
|
||||
|
||||
CountM - Supplies the maximum number of rows that can be processed for
|
||||
matrix A and matrix C. The actual number of rows handled for this
|
||||
invocation depends on the kernel implementation.
|
||||
|
||||
CountN - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
alpha - Supplies the scalar multiplier (see SGEMM definition).
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the number of rows handled.
|
||||
|
||||
--*/
|
||||
{
|
||||
size_t RowsHandled;
|
||||
|
||||
// Replicate alpha into every vector lane once, for use by all store paths.
MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha);
|
||||
|
||||
// Handle 4, 2 or 1 rows in this call, choosing the largest block that fits
// in CountM. The value returned is the block size used (4, 2 or 1).
// NOTE(review): CountM == 0 falls into the single-row case; presumably
// callers never pass zero -- confirm.
if (CountM >= 4) {
|
||||
RowsHandled = MlasSgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else if (CountM >= 2) {
|
||||
RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
} else {
|
||||
RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode);
|
||||
}
|
||||
|
||||
return RowsHandled;
|
||||
}
|
||||
|
|
@ -434,6 +434,15 @@ Return Value:
|
|||
t1 = _mm_movehl_ps(z2, z0);
|
||||
t2 = _mm_movelh_ps(z1, z3);
|
||||
t3 = _mm_movehl_ps(z3, z1);
|
||||
#elif defined(MLAS_VSX_INTRINSICS)
|
||||
__vector float z0 = vec_mergeh(t0, t2);
|
||||
__vector float z1 = vec_mergel(t0, t2);
|
||||
__vector float z2 = vec_mergeh(t1, t3);
|
||||
__vector float z3 = vec_mergel(t1, t3);
|
||||
t0 = vec_mergeh(z0, z2);
|
||||
t1 = vec_mergel(z0, z2);
|
||||
t2 = vec_mergeh(z1, z3);
|
||||
t3 = vec_mergel(z1, z3);
|
||||
#else
|
||||
#error Unsupported architecture.
|
||||
#endif
|
||||
|
|
@ -627,10 +636,10 @@ Return Value:
|
|||
|
||||
if ((CountY & 2) != 0) {
|
||||
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
|
||||
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]);
|
||||
|
||||
#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_VSX_INTRINSICS)
|
||||
MlasStoreLaneFloat32x4<0>(&d[0], t0);
|
||||
MlasStoreLaneFloat32x4<0>(&d[1], t1);
|
||||
MlasStoreLaneFloat32x4<1>(&d[16], t0);
|
||||
|
|
@ -640,9 +649,6 @@ Return Value:
|
|||
MlasStoreLaneFloat32x4<3>(&d[48], t0);
|
||||
MlasStoreLaneFloat32x4<3>(&d[49], t1);
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
|
||||
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]);
|
||||
|
||||
__m128 v0 = _mm_unpacklo_ps(t0, t1);
|
||||
__m128 v1 = _mm_unpackhi_ps(t0, t1);
|
||||
_mm_storel_pi((__m64*)&d[0], v0);
|
||||
|
|
@ -762,6 +768,85 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
float*
MlasSgemmKernelLoop(
    const float* A,
    const float* B,
    float* C,
    size_t CountK,
    size_t CountM,
    size_t CountN,
    size_t lda,
    size_t ldc,
    float alpha,
    bool ZeroMode
    )
/*++

Routine Description:

    This routine repeatedly invokes the platform SGEMM kernel, advancing the
    A and C row pointers by the number of rows each call reports as handled,
    until the row count has been consumed.

    The kernel is invoked before the remaining row count is tested, so it is
    always called at least once.
    NOTE(review): callers presumably always pass a nonzero CountM - confirm.

Arguments:

    A - Supplies the address of matrix A.

    B - Supplies the address of matrix B. The matrix data has been packed
        using MlasSgemmCopyPackB or MlasSgemmTransposePackB.

    C - Supplies the address of matrix C.

    CountK - Supplies the number of columns from matrix A and the number of
        rows from matrix B to iterate over.

    CountM - Supplies the number of rows from matrix A and matrix C to
        iterate over.

    CountN - Supplies the number of columns from matrix B and matrix C to
        iterate over.

    lda - Supplies the first dimension of matrix A.

    ldc - Supplies the first dimension of matrix C.

    alpha - Supplies the scalar alpha multiplier (see SGEMM definition).

    ZeroMode - Supplies true if the output matrix must be zero initialized,
        else false if the output matrix is accumulated into.

Return Value:

    Returns the next address of matrix C, i.e. the address of the row that
    follows the last row handled.

--*/
{
    for (;;) {

        size_t RowsDone;

        //
        // Select the kernel entry point for the target architecture.
        //

#if defined(MLAS_TARGET_AMD64_IX86)
        RowsDone = MlasPlatform.GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
#elif defined(MLAS_TARGET_POWER)
        RowsDone = MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
#else
        RowsDone = ZeroMode ?
            MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha) :
            MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
#endif

        //
        // Advance past the rows that the kernel just produced.
        //

        A += lda * RowsDone;
        C += ldc * RowsDone;
        CountM -= RowsDone;

        if (CountM == 0) {
            break;
        }
    }

    return C;
}
|
||||
|
||||
void
|
||||
MlasSgemmOperation(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
|
|
@ -856,6 +941,7 @@ Return Value:
|
|||
// Transpose(A*B) = Transpose(B) * Transpose(A), we can apply the same 'small-M'
|
||||
// optimization as above, with A and B flipped.
|
||||
//
|
||||
|
||||
if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) {
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
|
|
@ -877,7 +963,6 @@ Return Value:
|
|||
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Compute the strides to step through slices of the input matrices.
|
||||
//
|
||||
|
|
@ -931,9 +1016,9 @@ Return Value:
|
|||
// Step through each slice of matrix B along the K dimension.
|
||||
//
|
||||
|
||||
for (size_t k = 0; k < K; k += CountK) {
|
||||
bool ZeroMode = (beta == 0.0f);
|
||||
|
||||
bool ZeroMode = (k == 0 && beta == 0.0f);
|
||||
for (size_t k = 0; k < K; k += CountK) {
|
||||
|
||||
CountK = StrideK;
|
||||
|
||||
|
|
@ -957,39 +1042,14 @@ Return Value:
|
|||
|
||||
float* c = C + n;
|
||||
|
||||
size_t RowsRemaining = M;
|
||||
size_t RowsHandled;
|
||||
|
||||
if (TransA == CblasNoTrans) {
|
||||
|
||||
const float* a = A + k;
|
||||
|
||||
//
|
||||
// Step through the rows of matrix A.
|
||||
//
|
||||
|
||||
do {
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
RowsHandled = MlasPlatform.GemmFloatKernel(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha, ZeroMode);
|
||||
#else
|
||||
if (ZeroMode) {
|
||||
RowsHandled = MlasSgemmKernelZero(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha);
|
||||
} else {
|
||||
RowsHandled = MlasSgemmKernelAdd(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha);
|
||||
}
|
||||
#endif
|
||||
|
||||
c += ldc * RowsHandled;
|
||||
a += lda * RowsHandled;
|
||||
|
||||
RowsRemaining -= RowsHandled;
|
||||
|
||||
} while (RowsRemaining > 0);
|
||||
MlasSgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode);
|
||||
|
||||
} else {
|
||||
|
||||
const float* a = A + k * lda;
|
||||
size_t RowsRemaining = M;
|
||||
|
||||
do {
|
||||
|
||||
|
|
@ -1003,39 +1063,21 @@ Return Value:
|
|||
RowsTransposed = MLAS_SGEMM_TRANSA_ROWS;
|
||||
}
|
||||
|
||||
RowsRemaining -= RowsTransposed;
|
||||
|
||||
MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK);
|
||||
|
||||
RowsRemaining -= RowsTransposed;
|
||||
a += RowsTransposed;
|
||||
|
||||
//
|
||||
// Step through the rows of the local buffer.
|
||||
//
|
||||
|
||||
const float* pa = PanelA;
|
||||
|
||||
do {
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
RowsHandled = MlasPlatform.GemmFloatKernel(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
|
||||
#else
|
||||
if (ZeroMode) {
|
||||
RowsHandled = MlasSgemmKernelZero(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha);
|
||||
} else {
|
||||
RowsHandled = MlasSgemmKernelAdd(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha);
|
||||
}
|
||||
#endif
|
||||
|
||||
c += ldc * RowsHandled;
|
||||
pa += CountK * RowsHandled;
|
||||
|
||||
RowsTransposed -= RowsHandled;
|
||||
|
||||
} while (RowsTransposed > 0);
|
||||
c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
|
||||
|
||||
} while (RowsRemaining > 0);
|
||||
}
|
||||
|
||||
ZeroMode = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1939,7 +1939,7 @@ public:
|
|||
Buffer[i].u = TestData[i][0].u;
|
||||
}
|
||||
|
||||
MlasActivation(&Activation, &Buffer[0].f, nullptr, _countof(Buffer), 1, 1);
|
||||
MlasActivation(&Activation, &Buffer[0].f, nullptr, 1, _countof(Buffer), _countof(Buffer));
|
||||
|
||||
for (unsigned i = 0; i < _countof(TestData); i++) {
|
||||
// Sensitive to comparing positive/negative zero and NaNs.
|
||||
|
|
|
|||
Loading…
Reference in a new issue