diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 1e19c11b30..a9c5ffec37 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -10,6 +10,7 @@ set(mlas_common_srcs ${ONNXRUNTIME_ROOT}/core/mlas/lib/activate.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/logistic.cpp ${ONNXRUNTIME_ROOT}/core/mlas/lib/tanh.cpp + ${ONNXRUNTIME_ROOT}/core/mlas/lib/erf.cpp ) if(MSVC) @@ -65,6 +66,7 @@ if(MSVC) ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/cvtfp16a.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/LogisticKernelFma3.asm ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/TanhKernelFma3.asm + ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/ErfKernelFma3.asm ) endif() @@ -157,6 +159,7 @@ else() ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/LogisticKernelFma3.S ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/TanhKernelFma3.S + ${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/ErfKernelFma3.S ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 849c258c6b..dfa452b444 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -226,6 +226,14 @@ MlasComputeTanh( size_t N ); +void +MLASCALL +MlasComputeErf( + const float* Input, + float* Output, + size_t N + ); + // // Half-precision floating-point routines. // diff --git a/onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm new file mode 100644 index 0000000000..96b3df9513 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm @@ -0,0 +1,569 @@ +;++ +; +; Copyright (c) Microsoft Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; ErfKernelFma3.asm +; +; Abstract: +; +; This module implements a kernel for computing the error function for a +; buffer of elements. 
+; +; This implementation uses AVX fused multiply/add instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + EXTERN MlasMaskMoveAvx:NEAR + EXTERN MlasErfConstants:NEAR + +; +; Structure layout for the erf constants block. +; + +ErfConstants STRUCT + + ErfUpperAbsRange DWORD ? + ErfSplitBoundary DWORD ? + ErfSMALL_P0 DWORD ? + ErfSMALL_P1 DWORD ? + ErfSMALL_P2 DWORD ? + ErfSMALL_P3 DWORD ? + ErfSMALL_P4 DWORD ? + ErfSMALL_P5_Minus_One DWORD ? + ErfReserve0 DWORD ? + ErfBIG_P0 DWORD ? + ErfBIG_P1 DWORD ? + ErfBIG_P2 DWORD ? + ErfBIG_P3 DWORD ? + ErfBIG_P4 DWORD ? + ErfBIG_P5 DWORD ? + ErfBIG_P6_Minus_One DWORD ? + ErfNegZero DWORD ? + ErfOne DWORD ? + + Exp_UpperRange DWORD ? + Exp_LowerRange DWORD ? + Exp_Log2Reciprocal DWORD ? + Exp_log2_hi DWORD ? + Exp_log2_lo DWORD ? + Exp_P0 DWORD ? + Exp_P1 DWORD ? + Exp_P2 DWORD ? + Exp_P3 DWORD ? + Exp_P4 DWORD ? + Exp_P5 DWORD ? + Exp_P6 DWORD ? + Exp_C DWORD ? + Exp_X7F DWORD ? + +ErfConstants ENDS + +; +; Stack frame layout for the erf kernel. +; + +ErfKernelFrame STRUCT + + ErfBuffer0 OWORD 8 DUP(?) + ErfBuffer1 OWORD 8 DUP(?) + SavedXmm6 OWORD ? + SavedXmm7 OWORD ? + SavedXmm8 OWORD ? + SavedXmm9 OWORD ? + SavedXmm10 OWORD ? + SavedXmm11 OWORD ? + SavedXmm12 OWORD ? + SavedXmm13 OWORD ? + SavedXmm14 OWORD ? + SavedXmm15 OWORD ? + Padding0 QWORD ? + Padding1 QWORD ? + CountN QWORD ? + ReturnAddress QWORD ? + PreviousP1Home QWORD ? + PreviousP2Home QWORD ? + PreviousP3Home QWORD ? + PreviousP4Home QWORD ? + +ErfKernelFrame ENDS + +;++ +; +; Routine Description: +; +; This routine implements a vectorized kernel for the error function. +; +; Arguments: +; +; Input (rcx) - Supplies the input buffer. +; +; Output (rdx) - Supplies the output buffer. +; +; N (r8) - Supplies the number of elements to process. +; +; Return Value: +; +; None. 
+; +;-- + + NESTED_ENTRY MlasErfKernelFma3, _TEXT + + alloc_stack (ErfKernelFrame.ReturnAddress) + + save_xmm128_avx xmm6,ErfKernelFrame.SavedXmm6 + save_xmm128_avx xmm7,ErfKernelFrame.SavedXmm7 + save_xmm128_avx xmm8,ErfKernelFrame.SavedXmm8 + save_xmm128_avx xmm9,ErfKernelFrame.SavedXmm9 + save_xmm128_avx xmm10,ErfKernelFrame.SavedXmm10 + save_xmm128_avx xmm11,ErfKernelFrame.SavedXmm11 + save_xmm128_avx xmm12,ErfKernelFrame.SavedXmm12 + save_xmm128_avx xmm13,ErfKernelFrame.SavedXmm13 + save_xmm128_avx xmm14,ErfKernelFrame.SavedXmm14 + save_xmm128_avx xmm15,ErfKernelFrame.SavedXmm15 + + END_PROLOGUE + + lea rax,MlasErfConstants + sub r8,8*4 + jb LErfProcessRemainingCount + +LComputeErf4x8Loop: + vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] + vmovups ymm0,YMMWORD PTR [rcx] ; original input vx0 + vmovups ymm1,YMMWORD PTR [rcx+32] ; original input vx1 + vmovups ymm2,YMMWORD PTR [rcx+64] ; original input vx2 + vmovups ymm3,YMMWORD PTR [rcx+96] ; original input vx3 + + vandps ymm4,ymm0,ymm15 ; vsign0 + vandps ymm5,ymm1,ymm15 ; vsign1 + vandps ymm6,ymm2,ymm15 ; vsign2 + vandps ymm7,ymm3,ymm15 ; vsign3 + vandnps ymm0,ymm15,ymm0 ; abs(vx0) va0 + vandnps ymm1,ymm15,ymm1 ; abs(vx1) va1 + vandnps ymm2,ymm15,ymm2 ; abs(vx2) va2 + vandnps ymm3,ymm15,ymm3 ; abs(vx3) va3 + + vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax] + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4 + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+32],ymm5 + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+64],ymm6 + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+96],ymm7 + + vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax] + vminps ymm0,ymm0,ymm14 ; force abs value in range + vminps ymm1,ymm1,ymm14 + vminps ymm2,ymm2,ymm14 + vminps ymm3,ymm3,ymm14 + vmovaps ymm9,ymm8 + vmovaps ymm10,ymm8 + vmovaps ymm11,ymm8 + + vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax] + vmulps ymm4,ymm0,ymm0 ; vs0 (square) + vmulps ymm5,ymm1,ymm1 ; vs1 + vmulps ymm6,ymm2,ymm2 ; vs2 + vmulps ymm7,ymm3,ymm3 ; 
vs3 + + vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax] + vfmadd213ps ymm8,ymm4,ymm15 + vfmadd213ps ymm9,ymm5,ymm15 + vfmadd213ps ymm10,ymm6,ymm15 + vfmadd213ps ymm11,ymm7,ymm15 + + vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax] + vfmadd213ps ymm8,ymm4,ymm14 + vfmadd213ps ymm9,ymm5,ymm14 + vfmadd213ps ymm10,ymm6,ymm14 + vfmadd213ps ymm11,ymm7,ymm14 + + vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax] + vfmadd213ps ymm8,ymm4,ymm13 + vfmadd213ps ymm9,ymm5,ymm13 + vfmadd213ps ymm10,ymm6,ymm13 + vfmadd213ps ymm11,ymm7,ymm13 + + vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax] + vfmadd213ps ymm8,ymm4,ymm15 + vfmadd213ps ymm9,ymm5,ymm15 + vfmadd213ps ymm10,ymm6,ymm15 + vfmadd213ps ymm11,ymm7,ymm15 + + vfmadd213ps ymm8,ymm4,ymm14 + vfmadd213ps ymm9,ymm5,ymm14 + vfmadd213ps ymm10,ymm6,ymm14 + vfmadd213ps ymm11,ymm7,ymm14 + + vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax] + vfmadd213ps ymm8,ymm0,ymm0 + vfmadd213ps ymm9,ymm1,ymm1 + vfmadd213ps ymm10,ymm2,ymm2 + vfmadd213ps ymm11,ymm3,ymm3 + + vcmpgtps ymm4,ymm0,ymm12 ; vmask0 + vcmpgtps ymm5,ymm1,ymm12 ; vmask1 + vcmpgtps ymm6,ymm2,ymm12 ; vmask2 + vcmpgtps ymm7,ymm3,ymm12 ; vmask3 + + vandnps ymm8,ymm4,ymm8 + vandnps ymm9,ymm5,ymm9 + vandnps ymm10,ymm6,ymm10 + vandnps ymm11,ymm7,ymm11 + + vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax] + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8 + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32],ymm9 + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64],ymm10 + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96],ymm11 + +LBiggerNumbers: + vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax] + vandps ymm0,ymm4,ymm0 + vandps ymm1,ymm5,ymm1 + vandps ymm2,ymm6,ymm2 + vandps ymm3,ymm7,ymm3 + vmovaps ymm9,ymm8 + vmovaps ymm10,ymm8 + vmovaps ymm11,ymm8 + + vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax] + vfmadd213ps ymm8,ymm0,ymm15 + vfmadd213ps ymm9,ymm1,ymm15 + vfmadd213ps ymm10,ymm2,ymm15 + vfmadd213ps ymm11,ymm3,ymm15 + + vbroadcastss 
ymm13,ErfConstants.ErfBIG_P3[rax] + vfmadd213ps ymm8,ymm0,ymm14 + vfmadd213ps ymm9,ymm1,ymm14 + vfmadd213ps ymm10,ymm2,ymm14 + vfmadd213ps ymm11,ymm3,ymm14 + + vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax] + vfmadd213ps ymm8,ymm0,ymm13 + vfmadd213ps ymm9,ymm1,ymm13 + vfmadd213ps ymm10,ymm2,ymm13 + vfmadd213ps ymm11,ymm3,ymm13 + + vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax] + vfmadd213ps ymm8,ymm0,ymm15 + vfmadd213ps ymm9,ymm1,ymm15 + vfmadd213ps ymm10,ymm2,ymm15 + vfmadd213ps ymm11,ymm3,ymm15 + + vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax] + vfmadd213ps ymm8,ymm0,ymm14 + vfmadd213ps ymm9,ymm1,ymm14 + vfmadd213ps ymm10,ymm2,ymm14 + vfmadd213ps ymm11,ymm3,ymm14 + + vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] + vfmadd213ps ymm8,ymm0,ymm13 + vfmadd213ps ymm9,ymm1,ymm13 + vfmadd213ps ymm10,ymm2,ymm13 + vfmadd213ps ymm11,ymm3,ymm13 + + vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax] + vfmadd213ps ymm8,ymm0,ymm0 + vfmadd213ps ymm9,ymm1,ymm1 + vfmadd213ps ymm10,ymm2,ymm2 + vfmadd213ps ymm11,ymm3,ymm3 + + vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax] + vxorps ymm8,ymm8,ymm15 + vxorps ymm9,ymm9,ymm15 + vxorps ymm10,ymm10,ymm15 + vxorps ymm11,ymm11,ymm15 + + vbroadcastss ymm13,ErfConstants.Exp_C[rax] + vmovaps ymm5,ymm4 + vmovaps ymm6,ymm4 + vmovaps ymm7,ymm4 + + ; expf(ymm8 -- ymm11) + vmaxps ymm8,ymm8,ymm14 + vmaxps ymm9,ymm9,ymm14 + vmaxps ymm10,ymm10,ymm14 + vmaxps ymm11,ymm11,ymm14 + + vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax] + vfmadd213ps ymm4,ymm8,ymm13 + vfmadd213ps ymm5,ymm9,ymm13 + vfmadd213ps ymm6,ymm10,ymm13 + vfmadd213ps ymm7,ymm11,ymm13 + + vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax] + vmovaps ymm1,ymm0 + vmovaps ymm2,ymm0 + vmovaps ymm3,ymm0 + + vsubps ymm4,ymm4,ymm13 ; vr = round() + vsubps ymm5,ymm5,ymm13 + vsubps ymm6,ymm6,ymm13 + vsubps ymm7,ymm7,ymm13 + + vfmadd213ps ymm0,ymm4,ymm8 ; vf = vr * log2_hi + ve + vfmadd213ps ymm1,ymm5,ymm9 + vfmadd213ps ymm2,ymm6,ymm10 + vfmadd213ps ymm3,ymm7,ymm11 + + vbroadcastss 
ymm8,ErfConstants.Exp_P0[rax] + vfmadd231ps ymm0,ymm4,ymm15 ; vf += vr * log_2_lo + vfmadd231ps ymm1,ymm5,ymm15 + vfmadd231ps ymm2,ymm6,ymm15 + vfmadd231ps ymm3,ymm7,ymm15 + vmovaps ymm9,ymm8 + vmovaps ymm10,ymm8 + vmovaps ymm11,ymm8 + + vbroadcastss ymm14,ErfConstants.Exp_P1[rax] + vbroadcastss ymm13,ErfConstants.Exp_P2[rax] + vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p1 + vfmadd213ps ymm9,ymm1,ymm14 + vfmadd213ps ymm10,ymm2,ymm14 + vfmadd213ps ymm11,ymm3,ymm14 + + vbroadcastss ymm12,ErfConstants.Exp_P3[rax] + vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p2 + vfmadd213ps ymm9,ymm1,ymm13 + vfmadd213ps ymm10,ymm2,ymm13 + vfmadd213ps ymm11,ymm3,ymm13 + + vbroadcastss ymm15,ErfConstants.Exp_P4[rax] + vfmadd213ps ymm8,ymm0,ymm12 ; *+ exp_p3 + vfmadd213ps ymm9,ymm1,ymm12 + vfmadd213ps ymm10,ymm2,ymm12 + vfmadd213ps ymm11,ymm3,ymm12 + + vbroadcastss ymm14,ErfConstants.Exp_P5[rax] + vfmadd213ps ymm8,ymm0,ymm15 ; *+ exp_p4 + vfmadd213ps ymm9,ymm1,ymm15 + vfmadd213ps ymm10,ymm2,ymm15 + vfmadd213ps ymm11,ymm3,ymm15 + + vbroadcastss ymm13,ErfConstants.Exp_P6[rax] + vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p5 + vfmadd213ps ymm9,ymm1,ymm14 + vfmadd213ps ymm10,ymm2,ymm14 + vfmadd213ps ymm11,ymm3,ymm14 + + vbroadcastss ymm12,ErfConstants.Exp_X7F[rax] + vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p6 + vfmadd213ps ymm9,ymm1,ymm13 + vfmadd213ps ymm10,ymm2,ymm13 + vfmadd213ps ymm11,ymm3,ymm13 + + vcvttps2dq ymm4,ymm4 + vcvttps2dq ymm5,ymm5 + vcvttps2dq ymm6,ymm6 + vcvttps2dq ymm7,ymm7 + + vbroadcastss ymm15,ErfConstants.ErfOne[rax] + vpaddd ymm4,ymm4,ymm12 ; +127 + vpaddd ymm5,ymm5,ymm12 + vpaddd ymm6,ymm6,ymm12 + vpaddd ymm7,ymm7,ymm12 + + vpslld ymm4,ymm4,23 + vpslld ymm5,ymm5,23 + vpslld ymm6,ymm6,23 + vpslld ymm7,ymm7,23 + + vmulps ymm8,ymm8,ymm4 ; 2^i * exp(vf) + vmulps ymm9,ymm9,ymm5 + vmulps ymm10,ymm10,ymm6 + vmulps ymm11,ymm11,ymm7 + + vsubps ymm8,ymm15,ymm8 + vsubps ymm9,ymm15,ymm9 + vsubps ymm10,ymm15,ymm10 + vsubps ymm11,ymm15,ymm11 + + ; merge small numbers' result + vorps ymm8,ymm8,YMMWORD PTR 
ErfKernelFrame.ErfBuffer1[rsp] + vorps ymm9,ymm9,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32] + vorps ymm10,ymm10,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64] + vorps ymm11,ymm11,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96] + + ; copy sign + vorps ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp] + vorps ymm1,ymm9,YMMWORD PTR 32+ErfKernelFrame.ErfBuffer0[rsp] + vorps ymm2,ymm10,YMMWORD PTR 64+ErfKernelFrame.ErfBuffer0[rsp] + vorps ymm3,ymm11,YMMWORD PTR 96+ErfKernelFrame.ErfBuffer0[rsp] + + vmovups YMMWORD PTR [rdx],ymm0 + vmovups YMMWORD PTR [rdx+32],ymm1 + vmovups YMMWORD PTR [rdx+64],ymm2 + vmovups YMMWORD PTR [rdx+96],ymm3 + + add rcx,32*4 ; advance by 4*8 elements + add rdx,32*4 + sub r8,32 + jae LComputeErf4x8Loop + +LErfProcessRemainingCount: + add r8,32 ; correct for over-subtract above + jz LErfBatchExp + +LErfProcess1x8: + mov DWORD PTR ErfKernelFrame.CountN[rsp],r8d + vbroadcastss ymm3,DWORD PTR ErfKernelFrame.CountN[rsp] + + vpcmpgtd ymm3,ymm3,YMMWORD PTR [MlasMaskMoveAvx] + vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] + vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx] ; original input vx0 + + vandps ymm4,ymm0,ymm15 ; vsign0 + vandnps ymm0,ymm15,ymm0 ; abs(vx0) va0 + + vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax] + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4 + + vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax] + vminps ymm0,ymm0,ymm14 ; force abs value in range + + vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax] + vmulps ymm4,ymm0,ymm0 ; vs0 (square) + + vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax] + vfmadd213ps ymm8,ymm4,ymm15 + + vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax] + vfmadd213ps ymm8,ymm4,ymm14 + + vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax] + vfmadd213ps ymm8,ymm4,ymm13 + + vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax] + vfmadd213ps ymm8,ymm4,ymm15 + + vfmadd213ps ymm8,ymm4,ymm14 + + vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax] + vfmadd213ps ymm8,ymm0,ymm0 + + vcmpgtps ymm4,ymm0,ymm12 ; 
vmask0 + + vandnps ymm8,ymm4,ymm8 + + vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8 + +LBiggerNumbersRemaining: + vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax] + vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax] + vandps ymm0,ymm4,ymm0 + + vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax] + vfmadd213ps ymm8,ymm0,ymm15 + + vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax] + vfmadd213ps ymm8,ymm0,ymm14 + + vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax] + vfmadd213ps ymm8,ymm0,ymm13 + + vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax] + vfmadd213ps ymm8,ymm0,ymm15 + + vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax] + vfmadd213ps ymm8,ymm0,ymm14 + + vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] + vfmadd213ps ymm8,ymm0,ymm13 + + vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax] + vfmadd213ps ymm8,ymm0,ymm0 + + vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax] + vxorps ymm8,ymm8,ymm15 + + vbroadcastss ymm13,ErfConstants.Exp_C[rax] + + ; expf(ymm8) + vmaxps ymm8,ymm8,ymm14 + + vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax] + vfmadd213ps ymm4,ymm8,ymm13 + + vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax] + + vsubps ymm4,ymm4,ymm13 ; vr = round() + + vfmadd213ps ymm0,ymm4,ymm8 ; vf = vr * log2_hi + ve + + vbroadcastss ymm8,ErfConstants.Exp_P0[rax] + + vfmadd231ps ymm0,ymm4,ymm15 ; vf += vr * log_2_lo + + vbroadcastss ymm14,ErfConstants.Exp_P1[rax] + + vbroadcastss ymm13,ErfConstants.Exp_P2[rax] + vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p1 + + vbroadcastss ymm12,ErfConstants.Exp_P3[rax] + vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p2 + + vbroadcastss ymm15,ErfConstants.Exp_P4[rax] + vfmadd213ps ymm8,ymm0,ymm12 ; *+ exp_p3 + + vbroadcastss ymm14,ErfConstants.Exp_P5[rax] + vfmadd213ps ymm8,ymm0,ymm15 ; *+ exp_p4 + + vbroadcastss ymm13,ErfConstants.Exp_P6[rax] + vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p5 + + vbroadcastss ymm12,ErfConstants.Exp_X7F[rax] + vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p6 + + vcvttps2dq ymm4,ymm4 + + vbroadcastss 
ymm15,ErfConstants.ErfOne[rax] + vpaddd ymm4,ymm4,ymm12 ; +127 + + vpslld ymm4,ymm4,23 + + vmulps ymm8,ymm8,ymm4 ; 2^i * exp(vf) + + vsubps ymm8,ymm15,ymm8 + + ; merge small numbers' result + vorps ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp] + + ; copy sign + vorps ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp] + + vmaskmovps YMMWORD PTR [rdx],ymm3,ymm0 + + add rcx,8*4 + add rdx,8*4 + sub r8,8 + jg LErfProcess1x8 + +LErfBatchExp: + vzeroupper + vmovaps xmm6,ErfKernelFrame.SavedXmm6[rsp] + vmovaps xmm7,ErfKernelFrame.SavedXmm7[rsp] + vmovaps xmm8,ErfKernelFrame.SavedXmm8[rsp] + vmovaps xmm9,ErfKernelFrame.SavedXmm9[rsp] + vmovaps xmm10,ErfKernelFrame.SavedXmm10[rsp] + vmovaps xmm11,ErfKernelFrame.SavedXmm11[rsp] + vmovaps xmm12,ErfKernelFrame.SavedXmm12[rsp] + vmovaps xmm13,ErfKernelFrame.SavedXmm13[rsp] + vmovaps xmm14,ErfKernelFrame.SavedXmm14[rsp] + vmovaps xmm15,ErfKernelFrame.SavedXmm15[rsp] + add rsp,(ErfKernelFrame.ReturnAddress) + + BEGIN_EPILOGUE + + ret + + NESTED_END MlasErfKernelFma3, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp new file mode 100644 index 0000000000..12fd0a368d --- /dev/null +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -0,0 +1,271 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + erf.cpp + +Abstract: + + This module implements routines to compute the error function. + + This implementation uses the same polynomial coefficients and algorithm as + found in: https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff + Our usage requires building platform specific versions of + the algorithm to target different instruction sets. The implementation below + targets the base instruction set (typically SSE2) while assembly + implementations target newer instruction sets (such as FMA3). 
+ +--*/ + +#include "mlasi.h" + +#include <algorithm> + +// +// Bundles the constants for use by kernels written in assembly. +// + +extern "C" const struct { + float ErfUpperAbsRange; + float ErfSplitBoundary; + float ErfSMALL_P0; + float ErfSMALL_P1; + float ErfSMALL_P2; + float ErfSMALL_P3; + float ErfSMALL_P4; + float ErfSMALL_P5_Minus_One; + float ErfReserved0; + float ErfBIG_P0; + float ErfBIG_P1; + float ErfBIG_P2; + float ErfBIG_P3; + float ErfBIG_P4; + float ErfBIG_P5; + float ErfBIG_P6_Minus_One; + float ErfNegZero; + float ErfOne; + + float Exp_UpperRange; + float Exp_LowerRange; + float Exp_Log2Reciprocal; + float Exp_log2_hi; + float Exp_log2_lo; + float Exp_P0; + float Exp_P1; + float Exp_P2; + float Exp_P3; + float Exp_P4; + float Exp_P5; + float Exp_P6; + float Exp_C; + int32_t Exp_X7F; +} MlasErfConstants = { + 3.725f, + 0.921875f, + -5.99104969e-4f, + 4.99339588e-3f, + -2.67667342e-2f, + 1.12818025e-1f, + -3.76124859e-1f, + 1.28379151e-1f, + 0.0f, + 1.72948930e-5f, + -3.83208680e-4f, + 3.88393435e-3f, + -2.42545605e-2f, + 1.06777847e-1f, + 6.34846687e-1f, + 1.28717512e-1f, + -0.0f, + 1.0f, + + // Independent parameters to calculate Exp for Erff() + 88.3762626647950f, + -88.3762626647949f, + 1.44269504088896341f, + -6.93145752e-1f, + -1.42860677e-6f, + 1.38319808e-3f, + 8.37550033e-3f, + 4.16689515e-2f, + 1.66664466e-1f, + 4.99999851e-1f, + 1.00000000e+0f, + 1.00000000e+0f, + 1.25829120e+7f, + 127, +}; + +void +MLASCALL +MlasErfKernel( + const float* Input, + float* Output, + size_t N + ) +/*++ + +Routine Description: + + This routine implements the generic kernel for the error function. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + +Return Value: + + None. 
+ +--*/ +{ + while (N >= 4) { + MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input); + MLAS_FLOAT32X4 NegZero = MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero); + MLAS_FLOAT32X4 SignMask = MlasAndFloat32x4(Value, NegZero); + MLAS_FLOAT32X4 AbsValue = MlasAndNotFloat32x4(NegZero, Value); + AbsValue = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.ErfUpperAbsRange), AbsValue); + MLAS_FLOAT32X4 SquareValue = MlasMultiplyFloat32x4(AbsValue, AbsValue); + + MLAS_FLOAT32X4 r_small = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P0); + r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P1)); + r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P2)); + r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P3)); + r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P4)); + r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P5_Minus_One)); + r_small = MlasMultiplyAddFloat32x4(r_small, AbsValue, AbsValue); + MLAS_FLOAT32X4 split_mask = MlasGreaterThanFloat32x4(AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSplitBoundary)); + r_small = MlasAndNotFloat32x4(split_mask, r_small); + + AbsValue = MlasAndFloat32x4(split_mask, AbsValue); // clear smaller value into zero for bigger number calculation + MLAS_FLOAT32X4 r_big = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P0); + r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P1)); + r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P2)); + r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P3)); + r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P4)); + r_big = 
MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P5)); + r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P6_Minus_One)); + r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, AbsValue); + + // 1.0 - exp(-r_big), no need to do min() + r_big = MlasXorFloat32x4(r_big, MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero)); // -r_big + r_big = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.Exp_LowerRange), r_big); + MLAS_FLOAT32X4 exp_c = MlasBroadcastFloat32x4(MlasErfConstants.Exp_C); + MLAS_FLOAT32X4 r = MlasMultiplyAddFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.Exp_Log2Reciprocal), r_big, exp_c); + r = MlasSubtractFloat32x4(r, exp_c); + + MLAS_FLOAT32X4 fx = MlasMultiplyAddFloat32x4(r, MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_hi), r_big); + fx = MlasMultiplyAddFloat32x4(r, MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_lo), fx); + // y = exp(fx) + MLAS_FLOAT32X4 y = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P0); + y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P1)); + y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P2)); + y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P3)); + y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P4)); + y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P5)); + y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P6)); + // 1.0 - exp(fx) * 2^INT(r) + y = MlasMultiplyFloat32x4(y, MlasPowerOf2Float32x4(r)); + y = MlasSubtractFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.ErfOne), y); + + // merge two splits results + y = MlasOrFloat32x4(r_small, y); + y = MlasOrFloat32x4(y, SignMask); + + MlasStoreFloat32x4(Output, y); + + Input += 4; + Output += 4; + N -= 4; + } + + while (N > 0) { + float Value = *Input++; + float AbsValue = 
fabsf(Value); + + float r; + if (AbsValue > MlasErfConstants.ErfSplitBoundary) { + AbsValue = (std::min)(MlasErfConstants.ErfUpperAbsRange, AbsValue); + float r_big = MlasErfConstants.ErfBIG_P0; + r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P1; + r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P2; + r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P3; + r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P4; + r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P5; + r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P6_Minus_One; + r_big = r_big * AbsValue + AbsValue; + + r_big = (std::max)(-r_big, MlasErfConstants.Exp_LowerRange); + r = MlasErfConstants.Exp_Log2Reciprocal * r_big + MlasErfConstants.Exp_C; + r -= MlasErfConstants.Exp_C; + float fx = r * MlasErfConstants.Exp_log2_hi + r_big; + fx = r * MlasErfConstants.Exp_log2_lo + fx; + + float y = MlasErfConstants.Exp_P0; + y = y * fx + MlasErfConstants.Exp_P1; + y = y * fx + MlasErfConstants.Exp_P2; + y = y * fx + MlasErfConstants.Exp_P3; + y = y * fx + MlasErfConstants.Exp_P4; + y = y * fx + MlasErfConstants.Exp_P5; + y = y * fx + MlasErfConstants.Exp_P6; + + r = 1.0f - ldexpf(y, (int)r); + r = (Value <= -0.0f) ? -r : r; + } + else { + float SquareValue = AbsValue * AbsValue; + r = MlasErfConstants.ErfSMALL_P0; + r = r * SquareValue + MlasErfConstants.ErfSMALL_P1; + r = r * SquareValue + MlasErfConstants.ErfSMALL_P2; + r = r * SquareValue + MlasErfConstants.ErfSMALL_P3; + r = r * SquareValue + MlasErfConstants.ErfSMALL_P4; + r = r * SquareValue + MlasErfConstants.ErfSMALL_P5_Minus_One; + r = r * Value + Value; + } + + *Output++ = r; + N -= 1; + } +} + +void +MLASCALL +MlasComputeErf( + const float* Input, + float* Output, + size_t N + ) +/*++ + +Routine Description: + + This routine computes the error function. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + +Return Value: + + None. 
+ +--*/ +{ +#if defined(MLAS_TARGET_AMD64) + MlasPlatform.ErfKernelRoutine(Input, Output, N); +#else + MlasErfKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 084a27b0c3..65ce61bd38 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -174,6 +174,16 @@ void typedef MLAS_TANH_KERNEL_ROUTINE* PMLAS_TANH_KERNEL_ROUTINE; +typedef +void +(MLASCALL MLAS_ERF_KERNEL_ROUTINE)( + const float* Input, + float* Output, + size_t N + ); + +typedef MLAS_ERF_KERNEL_ROUTINE* PMLAS_ERF_KERNEL_ROUTINE; + extern "C" { MLAS_SGEMM_KERNEL_ROUTINE MlasSgemmKernelZero; @@ -203,9 +213,11 @@ extern "C" { MLAS_TANH_KERNEL_ROUTINE MlasLogisticKernel; MLAS_TANH_KERNEL_ROUTINE MlasTanhKernel; + MLAS_ERF_KERNEL_ROUTINE MlasErfKernel; #if defined(MLAS_TARGET_AMD64) MLAS_TANH_KERNEL_ROUTINE MlasLogisticKernelFma3; MLAS_TANH_KERNEL_ROUTINE MlasTanhKernelFma3; + MLAS_ERF_KERNEL_ROUTINE MlasErfKernelFma3; #endif } @@ -269,6 +281,7 @@ struct MLAS_PLATFORM { PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE TransposePackB16x4Routine; PMLAS_LOGISTIC_KERNEL_ROUTINE LogisticKernelRoutine; PMLAS_TANH_KERNEL_ROUTINE TanhKernelRoutine; + PMLAS_ERF_KERNEL_ROUTINE ErfKernelRoutine; #endif #if defined(MLAS_USE_WIN32_THREADPOOL) @@ -574,6 +587,75 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) #endif } +inline +MLAS_FLOAT32X4 +MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vreinterpretq_f32_u32(vcgtq_f32(Vector1, Vector2)); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_cmpgt_ps(Vector1, Vector2); +#endif +} + +inline +MLAS_FLOAT32X4 +MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2))); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_and_ps(Vector1, 
Vector2); +#endif +} + +inline +MLAS_FLOAT32X4 +MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2))); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_or_ps(Vector1, Vector2); +#endif +} + +inline +MLAS_FLOAT32X4 +MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vreinterpretq_f32_u32(vandq_u32(vmvnq_u32(vreinterpretq_u32_f32(VectorNot)), vreinterpretq_u32_f32(Vector))); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_andnot_ps(VectorNot, Vector); +#endif +} + +inline +MLAS_FLOAT32X4 +MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2))); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_xor_ps(Vector1, Vector2); +#endif +} + +// calc 2^int(N) +inline +MLAS_FLOAT32X4 +MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) +{ +#if defined(MLAS_NEON_INTRINSICS) + int32x4_t emm0 = vaddq_s32(vcvtq_s32_f32(Vector), vdupq_n_s32(0x7f)); + return vreinterpretq_f32_s32(vshlq_n_s32(emm0, 23)); +#elif defined(MLAS_SSE2_INTRINSICS) + __m128i emm0 = _mm_add_epi32(_mm_cvttps_epi32(Vector), _mm_set1_epi32(0x7f)); + return _mm_castsi128_ps(_mm_slli_epi32(emm0, 23)); +#endif +} + // // Reads a platform specific time stamp counter. 
//
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 88c3dd4579..904d16275d 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -90,6 +90,7 @@ Return Value:
     this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
     this->LogisticKernelRoutine = MlasLogisticKernel;
     this->TanhKernelRoutine = MlasTanhKernel;
+    this->ErfKernelRoutine = MlasErfKernel; // default erf kernel, set alongside the default logistic/tanh kernels
 #endif
 
 //
@@ -144,6 +145,7 @@ Return Value:
 
                 this->LogisticKernelRoutine = MlasLogisticKernelFma3;
                 this->TanhKernelRoutine = MlasTanhKernelFma3;
+                this->ErfKernelRoutine = MlasErfKernelFma3; // replaces the default in the same branch that installs the FMA3 logistic/tanh kernels
 
             } else {
 
diff --git a/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S
new file mode 100644
index 0000000000..29518fb911
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S
@@ -0,0 +1,517 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    ErfKernelFma3.s
+
+Abstract:
+
+    This module implements a kernel for computing the error function for a
+    buffer of elements.
+
+    This implementation uses AVX fused multiply/add instructions.
+
+--*/
+
+#include "asmmacro.h"
+
+        .intel_syntax noprefix
+
+        .text
+
+//
+// Structure layout for the erf constants block: byte offsets of 4-byte fields from the MlasErfConstants symbol, in the same order as the ErfConstants STRUCT of the .asm version.
+//
+        .equ    ErfUpperAbsRange, 0             // clamp applied to abs(x) before evaluation
+        .equ    ErfSplitBoundary, 4             // abs(x) above this uses the BIG polynomial + exp path
+        .equ    ErfSMALL_P0, 8                  // SMALL_P0..P5: coefficients evaluated in x*x, highest order first
+        .equ    ErfSMALL_P1, 12
+        .equ    ErfSMALL_P2, 16
+        .equ    ErfSMALL_P3, 20
+        .equ    ErfSMALL_P4, 24
+        .equ    ErfSMALL_P5_Minus_One, 28
+        .equ    ErfReserve0, 32                 // not referenced by this kernel
+        .equ    ErfBIG_P0, 36                   // BIG_P0..P6: coefficients evaluated in abs(x), highest order first
+        .equ    ErfBIG_P1, 40
+        .equ    ErfBIG_P2, 44
+        .equ    ErfBIG_P3, 48
+        .equ    ErfBIG_P4, 52
+        .equ    ErfBIG_P5, 56
+        .equ    ErfBIG_P6_Minus_One, 60
+        .equ    ErfNegZero, 64                  // used as the sign-bit mask (and/andn/xor)
+        .equ    ErfOne, 68                      // minuend of the final 1 - exp(...)
+
+        .equ    ExpConstOffset, 72              // first Exp_* field: follows the 18 erf fields (18 * 4 bytes)
+        .equ    Exp_UpperRange, 0 + ExpConstOffset      // not referenced by this kernel
+        .equ    Exp_LowerRange, 4 + ExpConstOffset      // lower clamp of the exp argument
+        .equ    Exp_Log2Reciprocal, 8 + ExpConstOffset
+        .equ    Exp_log2_hi, 12 + ExpConstOffset
+        .equ    Exp_log2_lo, 16 + ExpConstOffset
+        .equ    Exp_P0, 20 + ExpConstOffset             // Exp_P0..P6: exp polynomial coefficients, highest order first
+        .equ    Exp_P1, 24 + ExpConstOffset
+        .equ    Exp_P2, 28 + ExpConstOffset
+        .equ    Exp_P3, 32 + ExpConstOffset
+        .equ    Exp_P4, 36 + ExpConstOffset
+        .equ    Exp_P5, 40 + ExpConstOffset
+        .equ    Exp_P6, 44 + ExpConstOffset
+        .equ    Exp_C, 48 + ExpConstOffset              // added then subtracted to obtain the rounded multiple of log2
+        .equ    Exp_X7F, 52 + ExpConstOffset            // integer added before the shift into the exponent field
+
+//
+// Stack frame layout for the erf kernel (byte offsets from rsp after the prologue).
+//
+        .equ    ErfBuffer0, 0                   // 128 bytes: sign bits of the four input vectors
+        .equ    ErfBuffer1, 128                 // 128 bytes: masked small-range results awaiting the merge
+        .equ    ErfKernelFrame_CountN, 256      // scratch slot used to broadcast the remaining element count
+        .equ    ErfKernelFrame_ReturnAddress, 256+8     // also the number of bytes the prologue subtracts from rsp
+
+/*++
+
+Routine Description:
+
+    This routine implements a vectorized kernel for the error function.
+
+Arguments:
+
+    Input (rdi) - Supplies the input buffer.
+
+    Output (rsi) - Supplies the output buffer.
+
+    N (rdx) - Supplies the number of elements to process.
+
+Return Value:
+
+    None.
+
+--*/
+
+        .globl  C_UNDERSCORE(MlasErfKernelFma3)
+C_UNDERSCORE(MlasErfKernelFma3):
+        sub     rsp,ErfKernelFrame_ReturnAddress    # reserve ErfBuffer0, ErfBuffer1 and the CountN slot
+        mov     rax,C_UNDERSCORE(MlasErfConstants)@GOTPCREL[rip]    # rax = address of the constants block
+
+        sub     rdx,8*4                             # fewer than 32 elements: go straight to the tail loop
+        jb      .LErfProcessRemainingCount
+
+.LComputeErf4x8Loop:                                # main loop: four vectors of eight floats per pass
+        vbroadcastss ymm15,ErfNegZero[rax]
+        vmovups ymm0,YMMWORD PTR [rdi]              # original input vx0
+        vmovups ymm1,YMMWORD PTR [rdi+32]           # original input vx1
+        vmovups ymm2,YMMWORD PTR [rdi+64]           # original input vx2
+        vmovups ymm3,YMMWORD PTR [rdi+96]           # original input vx3
+
+        vandps  ymm4,ymm0,ymm15                     # vsign0
+        vandps  ymm5,ymm1,ymm15                     # vsign1
+        vandps  ymm6,ymm2,ymm15                     # vsign2
+        vandps  ymm7,ymm3,ymm15                     # vsign3
+        vandnps ymm0,ymm15,ymm0                     # abs(vx0) va0
+        vandnps ymm1,ymm15,ymm1                     # abs(vx1) va1
+        vandnps ymm2,ymm15,ymm2                     # abs(vx2) va2
+        vandnps ymm3,ymm15,ymm3                     # abs(vx3) va3
+
+        vbroadcastss ymm14,ErfUpperAbsRange[rax]
+        vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4    # park the sign bits until the final copy-sign
+        vmovups YMMWORD PTR ErfBuffer0[rsp+32],ymm5
+        vmovups YMMWORD PTR ErfBuffer0[rsp+64],ymm6
+        vmovups YMMWORD PTR ErfBuffer0[rsp+96],ymm7
+
+        vbroadcastss ymm8,ErfSMALL_P0[rax]          # small-range polynomial in vs, Horner form
+        vminps  ymm0,ymm0,ymm14                     # force abs value in range
+        vminps  ymm1,ymm1,ymm14
+        vminps  ymm2,ymm2,ymm14
+        vminps  ymm3,ymm3,ymm14
+        vmovaps ymm9,ymm8
+        vmovaps ymm10,ymm8
+        vmovaps ymm11,ymm8
+
+        vbroadcastss ymm15,ErfSMALL_P1[rax]
+        vmulps  ymm4,ymm0,ymm0                      # vs0 (square)
+        vmulps  ymm5,ymm1,ymm1                      # vs1
+        vmulps  ymm6,ymm2,ymm2                      # vs2
+        vmulps  ymm7,ymm3,ymm3                      # vs3
+
+        vbroadcastss ymm14,ErfSMALL_P2[rax]
+        vfmadd213ps ymm8,ymm4,ymm15
+        vfmadd213ps ymm9,ymm5,ymm15
+        vfmadd213ps ymm10,ymm6,ymm15
+        vfmadd213ps ymm11,ymm7,ymm15
+
+        vbroadcastss ymm13,ErfSMALL_P3[rax]
+        vfmadd213ps ymm8,ymm4,ymm14
+        vfmadd213ps ymm9,ymm5,ymm14
+        vfmadd213ps ymm10,ymm6,ymm14
+        vfmadd213ps ymm11,ymm7,ymm14
+
+        vbroadcastss ymm15,ErfSMALL_P4[rax]
+        vfmadd213ps ymm8,ymm4,ymm13
+        vfmadd213ps ymm9,ymm5,ymm13
+        vfmadd213ps ymm10,ymm6,ymm13
+        vfmadd213ps ymm11,ymm7,ymm13
+
+        vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax]
+        vfmadd213ps ymm8,ymm4,ymm15
+        vfmadd213ps ymm9,ymm5,ymm15
+        vfmadd213ps ymm10,ymm6,ymm15
+        vfmadd213ps ymm11,ymm7,ymm15
+
+        vfmadd213ps ymm8,ymm4,ymm14
+        vfmadd213ps ymm9,ymm5,ymm14
+        vfmadd213ps ymm10,ymm6,ymm14
+        vfmadd213ps ymm11,ymm7,ymm14
+
+        vbroadcastss ymm12,ErfSplitBoundary[rax]
+        vfmadd213ps ymm8,ymm0,ymm0                  # small result = va * poly(vs) + va
+        vfmadd213ps ymm9,ymm1,ymm1
+        vfmadd213ps ymm10,ymm2,ymm2
+        vfmadd213ps ymm11,ymm3,ymm3
+
+        vcmpgtps ymm4,ymm0,ymm12                    # vmask0
+        vcmpgtps ymm5,ymm1,ymm12                    # vmask1
+        vcmpgtps ymm6,ymm2,ymm12                    # vmask2
+        vcmpgtps ymm7,ymm3,ymm12                    # vmask3
+
+        vandnps ymm8,ymm4,ymm8                      # zero the small result where va > split boundary
+        vandnps ymm9,ymm5,ymm9
+        vandnps ymm10,ymm6,ymm10
+        vandnps ymm11,ymm7,ymm11
+
+        vbroadcastss ymm15,ErfBIG_P1[rax]
+        vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8    # park the masked small results until the merge
+        vmovups YMMWORD PTR ErfBuffer1[rsp+32],ymm9
+        vmovups YMMWORD PTR ErfBuffer1[rsp+64],ymm10
+        vmovups YMMWORD PTR ErfBuffer1[rsp+96],ymm11
+
+.LBiggerNumbers:                                    # large-range path: 1 - exp(-(va * poly(va) + va)); never branched to
+        vbroadcastss ymm8,ErfBIG_P0[rax]
+        vandps  ymm0,ymm4,ymm0                      # zero va in lanes at or below the split boundary
+        vandps  ymm1,ymm5,ymm1
+        vandps  ymm2,ymm6,ymm2
+        vandps  ymm3,ymm7,ymm3
+        vmovaps ymm9,ymm8
+        vmovaps ymm10,ymm8
+        vmovaps ymm11,ymm8
+
+        vbroadcastss ymm14,ErfBIG_P2[rax]
+        vfmadd213ps ymm8,ymm0,ymm15
+        vfmadd213ps ymm9,ymm1,ymm15
+        vfmadd213ps ymm10,ymm2,ymm15
+        vfmadd213ps ymm11,ymm3,ymm15
+
+        vbroadcastss ymm13,ErfBIG_P3[rax]
+        vfmadd213ps ymm8,ymm0,ymm14
+        vfmadd213ps ymm9,ymm1,ymm14
+        vfmadd213ps ymm10,ymm2,ymm14
+        vfmadd213ps ymm11,ymm3,ymm14
+
+        vbroadcastss ymm15,ErfBIG_P4[rax]
+        vfmadd213ps ymm8,ymm0,ymm13
+        vfmadd213ps ymm9,ymm1,ymm13
+        vfmadd213ps ymm10,ymm2,ymm13
+        vfmadd213ps ymm11,ymm3,ymm13
+
+        vbroadcastss ymm14,ErfBIG_P5[rax]
+        vfmadd213ps ymm8,ymm0,ymm15
+        vfmadd213ps ymm9,ymm1,ymm15
+        vfmadd213ps ymm10,ymm2,ymm15
+        vfmadd213ps ymm11,ymm3,ymm15
+
+        vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax]
+        vfmadd213ps ymm8,ymm0,ymm14
+        vfmadd213ps ymm9,ymm1,ymm14
+        vfmadd213ps ymm10,ymm2,ymm14
+        vfmadd213ps ymm11,ymm3,ymm14
+
+        vbroadcastss ymm15,ErfNegZero[rax]
+        vfmadd213ps ymm8,ymm0,ymm13
+        vfmadd213ps ymm9,ymm1,ymm13
+        vfmadd213ps ymm10,ymm2,ymm13
+        vfmadd213ps ymm11,ymm3,ymm13
+
+        vbroadcastss ymm14,Exp_LowerRange[rax]
+        vfmadd213ps ymm8,ymm0,ymm0                  # va * poly(va) + va
+        vfmadd213ps ymm9,ymm1,ymm1
+        vfmadd213ps ymm10,ymm2,ymm2
+        vfmadd213ps ymm11,ymm3,ymm3
+
+        vbroadcastss ymm4,Exp_Log2Reciprocal[rax]
+        vxorps  ymm8,ymm8,ymm15                     # flip the sign: ve, the argument handed to exp
+        vxorps  ymm9,ymm9,ymm15
+        vxorps  ymm10,ymm10,ymm15
+        vxorps  ymm11,ymm11,ymm15
+
+        vbroadcastss ymm13,Exp_C[rax]
+        vmovaps ymm5,ymm4
+        vmovaps ymm6,ymm4
+        vmovaps ymm7,ymm4
+
+        # expf(ymm8 -- ymm11)
+        vmaxps  ymm8,ymm8,ymm14                     # clamp ve from below at Exp_LowerRange
+        vmaxps  ymm9,ymm9,ymm14
+        vmaxps  ymm10,ymm10,ymm14
+        vmaxps  ymm11,ymm11,ymm14
+
+        vbroadcastss ymm0,Exp_log2_hi[rax]
+        vfmadd213ps ymm4,ymm8,ymm13                 # ve * Log2Reciprocal + Exp_C
+        vfmadd213ps ymm5,ymm9,ymm13
+        vfmadd213ps ymm6,ymm10,ymm13
+        vfmadd213ps ymm7,ymm11,ymm13
+
+        vbroadcastss ymm15,Exp_log2_lo[rax]
+        vmovaps ymm1,ymm0
+        vmovaps ymm2,ymm0
+        vmovaps ymm3,ymm0
+
+        vsubps  ymm4,ymm4,ymm13                     # vr = round()
+        vsubps  ymm5,ymm5,ymm13
+        vsubps  ymm6,ymm6,ymm13
+        vsubps  ymm7,ymm7,ymm13
+
+        vfmadd213ps ymm0,ymm4,ymm8                  # vf = vr * log2_hi + ve
+        vfmadd213ps ymm1,ymm5,ymm9
+        vfmadd213ps ymm2,ymm6,ymm10
+        vfmadd213ps ymm3,ymm7,ymm11
+
+        vbroadcastss ymm8,Exp_P0[rax]
+        vfmadd231ps ymm0,ymm4,ymm15                 # vf += vr * log_2_lo
+        vfmadd231ps ymm1,ymm5,ymm15
+        vfmadd231ps ymm2,ymm6,ymm15
+        vfmadd231ps ymm3,ymm7,ymm15
+        vmovaps ymm9,ymm8
+        vmovaps ymm10,ymm8
+        vmovaps ymm11,ymm8
+
+        vbroadcastss ymm14,Exp_P1[rax]
+        vbroadcastss ymm13,Exp_P2[rax]
+        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p1
+        vfmadd213ps ymm9,ymm1,ymm14
+        vfmadd213ps ymm10,ymm2,ymm14
+        vfmadd213ps ymm11,ymm3,ymm14
+
+        vbroadcastss ymm12,Exp_P3[rax]
+        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p2
+        vfmadd213ps ymm9,ymm1,ymm13
+        vfmadd213ps ymm10,ymm2,ymm13
+        vfmadd213ps ymm11,ymm3,ymm13
+
+        vbroadcastss ymm15,Exp_P4[rax]
+        vfmadd213ps ymm8,ymm0,ymm12                 # *+ exp_p3
+        vfmadd213ps ymm9,ymm1,ymm12
+        vfmadd213ps ymm10,ymm2,ymm12
+        vfmadd213ps ymm11,ymm3,ymm12
+
+        vbroadcastss ymm14,Exp_P5[rax]
+        vfmadd213ps ymm8,ymm0,ymm15                 # *+ exp_p4
+        vfmadd213ps ymm9,ymm1,ymm15
+        vfmadd213ps ymm10,ymm2,ymm15
+        vfmadd213ps ymm11,ymm3,ymm15
+
+        vbroadcastss ymm13,Exp_P6[rax]
+        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p5
+        vfmadd213ps ymm9,ymm1,ymm14
+        vfmadd213ps ymm10,ymm2,ymm14
+        vfmadd213ps ymm11,ymm3,ymm14
+
+        vbroadcastss ymm12,Exp_X7F[rax]
+        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p6
+        vfmadd213ps ymm9,ymm1,ymm13
+        vfmadd213ps ymm10,ymm2,ymm13
+        vfmadd213ps ymm11,ymm3,ymm13
+
+        vcvttps2dq ymm4,ymm4                        # vr as int32
+        vcvttps2dq ymm5,ymm5
+        vcvttps2dq ymm6,ymm6
+        vcvttps2dq ymm7,ymm7
+
+
+        vbroadcastss ymm15,ErfOne[rax]
+        vpaddd  ymm4,ymm4,ymm12                     # +127
+        vpaddd  ymm5,ymm5,ymm12
+        vpaddd  ymm6,ymm6,ymm12
+        vpaddd  ymm7,ymm7,ymm12
+
+        vpslld  ymm4,ymm4,23                        # move into the float exponent field: 2^vr
+        vpslld  ymm5,ymm5,23
+        vpslld  ymm6,ymm6,23
+        vpslld  ymm7,ymm7,23
+
+        vmulps  ymm8,ymm8,ymm4                      # 2^i * exp(vf)
+        vmulps  ymm9,ymm9,ymm5
+        vmulps  ymm10,ymm10,ymm6
+        vmulps  ymm11,ymm11,ymm7
+
+        vsubps  ymm8,ymm15,ymm8                     # 1 - exp(ve); presumably exactly 0 in the zeroed small-range lanes
+        vsubps  ymm9,ymm15,ymm9
+        vsubps  ymm10,ymm15,ymm10
+        vsubps  ymm11,ymm15,ymm11
+
+        # merge small numbers' result
+        vorps   ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp]
+        vorps   ymm9,ymm9,YMMWORD PTR ErfBuffer1[rsp+32]
+        vorps   ymm10,ymm10,YMMWORD PTR ErfBuffer1[rsp+64]
+        vorps   ymm11,ymm11,YMMWORD PTR ErfBuffer1[rsp+96]
+
+        # copy sign
+        vorps   ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp]
+        vorps   ymm1,ymm9,YMMWORD PTR ErfBuffer0[rsp+32]
+        vorps   ymm2,ymm10,YMMWORD PTR ErfBuffer0[rsp+64]
+        vorps   ymm3,ymm11,YMMWORD PTR ErfBuffer0[rsp+96]
+
+        vmovups YMMWORD PTR [rsi],ymm0
+        vmovups YMMWORD PTR [rsi+32],ymm1
+        vmovups YMMWORD PTR [rsi+64],ymm2
+        vmovups YMMWORD PTR [rsi+96],ymm3
+
+        add     rdi,32*4                            # advance by 4*8 elements
+        add     rsi,32*4
+        sub     rdx,32
+        jae     .LComputeErf4x8Loop                 # another full group of 32 is available
+
+.LErfProcessRemainingCount:
+        add     rdx,32                              # correct for over-subtract above
+        jz      .LErfBatchExp
+
+.LErfProcess1x8:                                    # tail loop: one masked vector (up to 8 floats) per pass
+        mov     DWORD PTR ErfKernelFrame_CountN[rsp],edx
+        mov     rcx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip]
+        vbroadcastss ymm3,DWORD PTR ErfKernelFrame_CountN[rsp]
+
+        vpcmpgtd ymm3,ymm3,YMMWORD PTR [rcx]        # lane mask = count > table entry (presumably lane indices 0..7 - TODO confirm)
+        vbroadcastss ymm15,ErfNegZero[rax]
+        vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi]      # original input vx0
+
+        vandps  ymm4,ymm0,ymm15                     # vsign0
+        vandnps ymm0,ymm15,ymm0                     # abs(vx0) va0
+
+        vbroadcastss ymm14,ErfUpperAbsRange[rax]
+        vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4    # park the sign bits until the final copy-sign
+
+        vbroadcastss ymm8,ErfSMALL_P0[rax]
+        vminps  ymm0,ymm0,ymm14                     # force abs value in range
+
+        vbroadcastss ymm15,ErfSMALL_P1[rax]
+        vmulps  ymm4,ymm0,ymm0                      # vs0 (square)
+
+        vbroadcastss ymm14,ErfSMALL_P2[rax]
+        vfmadd213ps ymm8,ymm4,ymm15
+
+        vbroadcastss ymm13,ErfSMALL_P3[rax]
+        vfmadd213ps ymm8,ymm4,ymm14
+
+        vbroadcastss ymm15,ErfSMALL_P4[rax]
+        vfmadd213ps ymm8,ymm4,ymm13
+
+        vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax]
+        vfmadd213ps ymm8,ymm4,ymm15
+
+        vfmadd213ps ymm8,ymm4,ymm14
+
+        vbroadcastss ymm12,ErfSplitBoundary[rax]
+        vfmadd213ps ymm8,ymm0,ymm0                  # small result = va * poly(vs) + va
+
+        vcmpgtps ymm4,ymm0,ymm12                    # vmask0
+
+        vandnps ymm8,ymm4,ymm8                      # zero the small result where va > split boundary
+
+        vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8    # park the masked small result until the merge
+
+.LBiggerNumbersRemaining:                           # large-range path for the tail vector; never branched to
+        vbroadcastss ymm15,ErfBIG_P1[rax]
+        vbroadcastss ymm8,ErfBIG_P0[rax]
+        vandps  ymm0,ymm4,ymm0                      # zero va in lanes at or below the split boundary
+
+        vbroadcastss ymm14,ErfBIG_P2[rax]
+        vfmadd213ps ymm8,ymm0,ymm15
+
+        vbroadcastss ymm13,ErfBIG_P3[rax]
+        vfmadd213ps ymm8,ymm0,ymm14
+
+        vbroadcastss ymm15,ErfBIG_P4[rax]
+        vfmadd213ps ymm8,ymm0,ymm13
+
+        vbroadcastss ymm14,ErfBIG_P5[rax]
+        vfmadd213ps ymm8,ymm0,ymm15
+
+        vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax]
+        vfmadd213ps ymm8,ymm0,ymm14
+
+        vbroadcastss ymm15,ErfNegZero[rax]
+        vfmadd213ps ymm8,ymm0,ymm13
+
+        vbroadcastss ymm14,Exp_LowerRange[rax]
+        vfmadd213ps ymm8,ymm0,ymm0                  # va * poly(va) + va
+
+        vbroadcastss ymm4,Exp_Log2Reciprocal[rax]
+        vxorps  ymm8,ymm8,ymm15                     # flip the sign: ve, the argument handed to exp
+
+        vbroadcastss ymm13,Exp_C[rax]
+
+        # expf(ymm8)
+        vmaxps  ymm8,ymm8,ymm14                     # clamp ve from below at Exp_LowerRange
+
+        vbroadcastss ymm0,Exp_log2_hi[rax]
+        vfmadd213ps ymm4,ymm8,ymm13                 # ve * Log2Reciprocal + Exp_C
+
+        vbroadcastss ymm15,Exp_log2_lo[rax]
+
+        vsubps  ymm4,ymm4,ymm13                     # vr = round()
+
+        vfmadd213ps ymm0,ymm4,ymm8                  # vf = vr * log2_hi + ve
+
+        vbroadcastss ymm8,Exp_P0[rax]
+
+        vfmadd231ps ymm0,ymm4,ymm15                 # vf += vr * log_2_lo
+
+        vbroadcastss ymm14,Exp_P1[rax]
+
+        vbroadcastss ymm13,Exp_P2[rax]
+        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p1
+
+        vbroadcastss ymm12,Exp_P3[rax]
+        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p2
+
+        vbroadcastss ymm15,Exp_P4[rax]
+        vfmadd213ps ymm8,ymm0,ymm12                 # *+ exp_p3
+
+        vbroadcastss ymm14,Exp_P5[rax]
+        vfmadd213ps ymm8,ymm0,ymm15                 # *+ exp_p4
+
+        vbroadcastss ymm13,Exp_P6[rax]
+        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p5
+
+        vbroadcastss ymm12,Exp_X7F[rax]
+        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p6
+
+        vcvttps2dq ymm4,ymm4                        # vr as int32
+
+        vbroadcastss ymm15,ErfOne[rax]
+        vpaddd  ymm4,ymm4,ymm12                     # +127
+
+        vpslld  ymm4,ymm4,23                        # move into the float exponent field: 2^vr
+
+        vmulps  ymm8,ymm8,ymm4                      # 2^i * exp(vf)
+
+        vsubps  ymm8,ymm15,ymm8                     # 1 - exp(ve)
+
+        # merge small numbers' result
+        vorps   ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp]
+
+        # copy sign
+        vorps   ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp]
+
+        vmaskmovps YMMWORD PTR [rsi],ymm3,ymm0      # store only the lanes enabled by the count mask
+
+        add     rdi,8*4
+        add     rsi,8*4
+        sub     rdx,8
+        jg      .LErfProcess1x8                     # more than 8 were pending: run another masked pass
+
+.LErfBatchExp:                                      # common exit
+        vzeroupper
+        add     rsp,ErfKernelFrame_ReturnAddress    # release the scratch frame reserved in the prologue
+        ret
+
+        .end
diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
index a19038ec7b..b56ff8f31a 100644
--- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -4,6 +4,7 @@
 #include "core/providers/cpu/math/element_wise_ops.h"
 #include
 #include "core/util/math.h"
+#include "core/mlas/inc/mlas.h"
 
 #include
 
@@ -1032,7 +1033,8 @@ Status Erf::Compute(OpKernelContext* context) const {
   ORT_ENFORCE(X_ptr != nullptr);
   auto& X = *X_ptr;
   auto& Y = *context->Output(0, X.Shape());
-  EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
+
+  MlasComputeErf(X.template Data<float>(), Y.template MutableData<float>(), X.Shape().Size());  // Y has X's shape, so one element count covers both buffers
 
   return Status::OK();
 }
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index e8bc5db12f..030f357b96 100644
---
a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -1025,6 +1025,35 @@ TEST(MathOpTest, Erf) {
   test.Run();
 }
 
+TEST(MathOpTest, ErfMoreData) {  // 59 values (32 + 27) so a 32-wide main loop and a masked tail both run
+  OpTester test("Erf", 9);
+  std::vector<float> inputs{  // both signs; magnitudes are either below 0.002 or at least 1.9375
+      -3.625f, 3.375f, 0.0f, 0.00025f, 0.0005f, -0.00075f, -0.001f, 0.00125f,
+      0.0015f, -3.125f, 0.00175f, 2.875f, 2.625f, 2.375f, 2.125f, 6.25e-05f,
+      0.0003125f, 0.0005625f, -0.0008125f, 0.0010625f, 0.0013125f, 0.0015625f, 0.0018125f, 3.5625f,
+      3.3125f, 3.0625f, 2.8125f, -2.5625f, 2.3125f, 2.0625f, 0.000125f, 0.000375f,
+      -0.000625f, -0.000875f, -0.001125f, -0.001375f, -0.001625f, -0.001875f, -3.5f, -3.25f,
+      3.0f, 2.75f, -2.5f, -2.25f, -2.0f, -0.0001875f, 0.0004375f, 0.0006875f,
+      2.1875f, -1.9375f, 0.0014375f, -0.0016875f, -0.0019375f, 3.4375f, 3.1875f, -2.9375f,
+      -2.4375f, -0.0009375f, 0.0011875f
+  };
+  std::vector<float> outputs{  // expected erf(inputs[i]), same order as inputs
+      -1.0f, 0.999998f, 0.0f, 0.000282095f, 0.00056419f, -0.000846284f, -0.00112838f, 0.00141047f,
+      0.00169257f, -0.99999f, 0.00197466f, 0.999952f, 0.999795f, 0.999217f, 0.997346f, 7.05237e-05f,
+      0.000352618f, 0.000634713f, -0.000916808f, 0.0011989f, 0.001481f, 0.00176309f, 0.00204518f, 1.0f,
+      0.999997f, 0.999985f, 0.99993f, -0.99971f, 0.998926f, 0.996464f, 0.000141047f, 0.000423142f,
+      -0.000705237f, -0.000987331f, -0.00126943f, -0.00155152f, -0.00183361f, -0.00211571f, -0.999999f, -0.999996f,
+      0.999978f, 0.999899f, -0.999593f, -0.998537f, -0.995322f, -0.000211571f, 0.000493666f, 0.000775761f,
+      0.998022f, -0.993857f, 0.00162204f, -0.00190414f, -0.00218623f, 0.999999f, 0.999993f, -0.999967f,
+      -0.999433f, -0.00105786f, 0.00133995f
+  };
+  std::vector<int64_t> dims{static_cast<int64_t>(inputs.size())};  // rank-1 tensor holding every sample
+
+  test.AddInput<float>("A", dims, inputs);
+  test.AddOutput<float>("B", dims, outputs);
+  test.Run();  // NOTE(review): nothing between 0.002 and 1.9375 is sampled - consider adding inputs around the kernel's split boundary
+}
+
 const int ModOp_ver = 10;
 
 TEST(ModOpTest, Fmod_float_mixed_sign) {