Zhalei/erff (#846)

Implement error function in mlas with avx2 optimization.
This commit is contained in:
Zhang Lei 2019-05-06 14:05:04 -07:00 committed by GitHub
parent 7e88ca19ee
commit 468de7c8af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 1484 additions and 1 deletions

View file

@ -10,6 +10,7 @@ set(mlas_common_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/activate.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/logistic.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/tanh.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/erf.cpp
)
if(MSVC)
@ -65,6 +66,7 @@ if(MSVC)
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/cvtfp16a.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/LogisticKernelFma3.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/TanhKernelFma3.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/ErfKernelFma3.asm
)
endif()
@ -157,6 +159,7 @@ else()
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/LogisticKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/TanhKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/ErfKernelFma3.S
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")

View file

@ -226,6 +226,14 @@ MlasComputeTanh(
size_t N
);
//
// Computes the error function for each of the N floats in Input and writes
// the results to the corresponding positions of Output.
//
void
MLASCALL
MlasComputeErf(
const float* Input,
float* Output,
size_t N
);
//
// Half-precision floating-point routines.
//

View file

@ -0,0 +1,569 @@
;++
;
; Copyright (c) Microsoft Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; ErfKernelFma3.asm
;
; Abstract:
;
; This module implements a kernel for computing the error function for a
; buffer of elements.
;
; This implementation uses AVX fused multiply/add instructions.
;
;--
.xlist
INCLUDE mlasi.inc
.list
EXTERN MlasMaskMoveAvx:NEAR
EXTERN MlasErfConstants:NEAR
;
; Structure layout for the erf constants block.
;
;
; Every member is one DWORD and is read with vbroadcastss relative to rax,
; which the kernel loads with the address of MlasErfConstants. The member
; order therefore has to stay in step with that data block.
;
ErfConstants STRUCT
; Clamp applied to abs(x), and the threshold selecting the "big" path.
ErfUpperAbsRange DWORD ?
ErfSplitBoundary DWORD ?
; Coefficients of the small-range polynomial (evaluated in x*x).
ErfSMALL_P0 DWORD ?
ErfSMALL_P1 DWORD ?
ErfSMALL_P2 DWORD ?
ErfSMALL_P3 DWORD ?
ErfSMALL_P4 DWORD ?
ErfSMALL_P5_Minus_One DWORD ?
; Not referenced by the kernel in this file.
ErfReserve0 DWORD ?
; Coefficients of the big-range polynomial (evaluated in abs(x)).
ErfBIG_P0 DWORD ?
ErfBIG_P1 DWORD ?
ErfBIG_P2 DWORD ?
ErfBIG_P3 DWORD ?
ErfBIG_P4 DWORD ?
ErfBIG_P5 DWORD ?
ErfBIG_P6_Minus_One DWORD ?
; Sign-bit mask used to split sign/magnitude and to negate, and the value the
; exponential is subtracted from.
ErfNegZero DWORD ?
ErfOne DWORD ?
; Constants for the inlined exponential. Exp_UpperRange is not referenced by
; the kernel in this file.
Exp_UpperRange DWORD ?
Exp_LowerRange DWORD ?
Exp_Log2Reciprocal DWORD ?
Exp_log2_hi DWORD ?
Exp_log2_lo DWORD ?
Exp_P0 DWORD ?
Exp_P1 DWORD ?
Exp_P2 DWORD ?
Exp_P3 DWORD ?
Exp_P4 DWORD ?
Exp_P5 DWORD ?
Exp_P6 DWORD ?
; Exp_C is added and then subtracted to produce the rounded value; Exp_X7F is
; the integer added to the converted exponent before it is shifted left by 23.
Exp_C DWORD ?
Exp_X7F DWORD ?
ErfConstants ENDS
;
; Stack frame layout for the erf kernel.
;
;
; The offset of ReturnAddress is the number of bytes the prologue allocates
; and the epilogue releases.
;
ErfKernelFrame STRUCT
; Scratch for four YMM vectors each: ErfBuffer0 holds the extracted sign bits,
; ErfBuffer1 holds the masked small-range results until the final merge.
ErfBuffer0 OWORD 8 DUP(?)
ErfBuffer1 OWORD 8 DUP(?)
; Save area for xmm6-xmm15, written in the prologue and reloaded before return.
SavedXmm6 OWORD ?
SavedXmm7 OWORD ?
SavedXmm8 OWORD ?
SavedXmm9 OWORD ?
SavedXmm10 OWORD ?
SavedXmm11 OWORD ?
SavedXmm12 OWORD ?
SavedXmm13 OWORD ?
SavedXmm14 OWORD ?
SavedXmm15 OWORD ?
; Padding0/Padding1 are not referenced by the kernel.
Padding0 QWORD ?
Padding1 QWORD ?
; Scratch slot used to broadcast the remaining element count in the tail loop.
CountN QWORD ?
ReturnAddress QWORD ?
; NOTE(review): presumably the caller-provided home slots for the four
; register parameters; the kernel does not touch them.
PreviousP1Home QWORD ?
PreviousP2Home QWORD ?
PreviousP3Home QWORD ?
PreviousP4Home QWORD ?
ErfKernelFrame ENDS
;++
;
; Routine Description:
;
; This routine implements a vectorized kernel for the error function.
;
; Arguments:
;
; Input (rcx) - Supplies the input buffer.
;
; Output (rdx) - Supplies the output buffer.
;
; N (r8) - Supplies the number of elements to process.
;
; Return Value:
;
; None.
;
;--
;
; Register roles inside the routine:
;   rcx/rdx  - running input/output pointers
;   r8       - remaining element count (biased by -32 while in the main loop)
;   rax      - base address of MlasErfConstants
;   ymm0-3   - magnitudes, later the reduced exp argument
;   ymm4-7   - squares, then split masks, then the rounded exp exponent
;   ymm8-11  - polynomial accumulators / results
;   ymm12-15 - broadcast constants
; vfmadd213ps a,b,c computes a = b*a + c; vfmadd231ps a,b,c computes a = b*c + a.
;
NESTED_ENTRY MlasErfKernelFma3, _TEXT
alloc_stack (ErfKernelFrame.ReturnAddress)
save_xmm128_avx xmm6,ErfKernelFrame.SavedXmm6
save_xmm128_avx xmm7,ErfKernelFrame.SavedXmm7
save_xmm128_avx xmm8,ErfKernelFrame.SavedXmm8
save_xmm128_avx xmm9,ErfKernelFrame.SavedXmm9
save_xmm128_avx xmm10,ErfKernelFrame.SavedXmm10
save_xmm128_avx xmm11,ErfKernelFrame.SavedXmm11
save_xmm128_avx xmm12,ErfKernelFrame.SavedXmm12
save_xmm128_avx xmm13,ErfKernelFrame.SavedXmm13
save_xmm128_avx xmm14,ErfKernelFrame.SavedXmm14
save_xmm128_avx xmm15,ErfKernelFrame.SavedXmm15
END_PROLOGUE
lea rax,MlasErfConstants
; Pre-subtract one full iteration (4 vectors x 8 floats); a borrow means fewer
; than 32 elements are available and only the tail loop runs.
sub r8,8*4
jb LErfProcessRemainingCount
;
; Main loop: four YMM vectors (32 floats) per iteration.
;
LComputeErf4x8Loop:
vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
vmovups ymm0,YMMWORD PTR [rcx] ; original input vx0
vmovups ymm1,YMMWORD PTR [rcx+32] ; original input vx1
vmovups ymm2,YMMWORD PTR [rcx+64] ; original input vx2
vmovups ymm3,YMMWORD PTR [rcx+96] ; original input vx3
; Separate the sign bit from the magnitude of every lane.
vandps ymm4,ymm0,ymm15 ; vsign0
vandps ymm5,ymm1,ymm15 ; vsign1
vandps ymm6,ymm2,ymm15 ; vsign2
vandps ymm7,ymm3,ymm15 ; vsign3
vandnps ymm0,ymm15,ymm0 ; abs(vx0) va0
vandnps ymm1,ymm15,ymm1 ; abs(vx1) va1
vandnps ymm2,ymm15,ymm2 ; abs(vx2) va2
vandnps ymm3,ymm15,ymm3 ; abs(vx3) va3
vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax]
; Park the sign bits in ErfBuffer0 so ymm4-ymm7 can be reused.
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+32],ymm5
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+64],ymm6
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+96],ymm7
vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax]
vminps ymm0,ymm0,ymm14 ; force abs value in range
vminps ymm1,ymm1,ymm14
vminps ymm2,ymm2,ymm14
vminps ymm3,ymm3,ymm14
vmovaps ymm9,ymm8
vmovaps ymm10,ymm8
vmovaps ymm11,ymm8
vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax]
vmulps ymm4,ymm0,ymm0 ; vs0 (square)
vmulps ymm5,ymm1,ymm1 ; vs1
vmulps ymm6,ymm2,ymm2 ; vs2
vmulps ymm7,ymm3,ymm3 ; vs3
; Small-range polynomial, Horner form in vs: p = p*vs + ErfSMALL_Pk.
vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax]
vfmadd213ps ymm8,ymm4,ymm15
vfmadd213ps ymm9,ymm5,ymm15
vfmadd213ps ymm10,ymm6,ymm15
vfmadd213ps ymm11,ymm7,ymm15
vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax]
vfmadd213ps ymm8,ymm4,ymm14
vfmadd213ps ymm9,ymm5,ymm14
vfmadd213ps ymm10,ymm6,ymm14
vfmadd213ps ymm11,ymm7,ymm14
vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax]
vfmadd213ps ymm8,ymm4,ymm13
vfmadd213ps ymm9,ymm5,ymm13
vfmadd213ps ymm10,ymm6,ymm13
vfmadd213ps ymm11,ymm7,ymm13
vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax]
vfmadd213ps ymm8,ymm4,ymm15
vfmadd213ps ymm9,ymm5,ymm15
vfmadd213ps ymm10,ymm6,ymm15
vfmadd213ps ymm11,ymm7,ymm15
vfmadd213ps ymm8,ymm4,ymm14
vfmadd213ps ymm9,ymm5,ymm14
vfmadd213ps ymm10,ymm6,ymm14
vfmadd213ps ymm11,ymm7,ymm14
vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax]
; Final small-range step: result = p*va + va.
vfmadd213ps ymm8,ymm0,ymm0
vfmadd213ps ymm9,ymm1,ymm1
vfmadd213ps ymm10,ymm2,ymm2
vfmadd213ps ymm11,ymm3,ymm3
; Lanes with va > ErfSplitBoundary take the big-range result instead, so the
; small result is cleared in those lanes.
vcmpgtps ymm4,ymm0,ymm12 ; vmask0
vcmpgtps ymm5,ymm1,ymm12 ; vmask1
vcmpgtps ymm6,ymm2,ymm12 ; vmask2
vcmpgtps ymm7,ymm3,ymm12 ; vmask3
vandnps ymm8,ymm4,ymm8
vandnps ymm9,ymm5,ymm9
vandnps ymm10,ymm6,ymm10
vandnps ymm11,ymm7,ymm11
vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax]
; Park the masked small-range results in ErfBuffer1 until the merge.
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32],ymm9
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64],ymm10
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96],ymm11
; Reached by fall-through only; nothing in this routine jumps to the label.
LBiggerNumbers:
vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax]
; Zero the magnitude in lanes that stay on the small-range path.
vandps ymm0,ymm4,ymm0
vandps ymm1,ymm5,ymm1
vandps ymm2,ymm6,ymm2
vandps ymm3,ymm7,ymm3
vmovaps ymm9,ymm8
vmovaps ymm10,ymm8
vmovaps ymm11,ymm8
; Big-range polynomial, Horner form in va: p = p*va + ErfBIG_Pk.
vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax]
vfmadd213ps ymm8,ymm0,ymm15
vfmadd213ps ymm9,ymm1,ymm15
vfmadd213ps ymm10,ymm2,ymm15
vfmadd213ps ymm11,ymm3,ymm15
vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax]
vfmadd213ps ymm8,ymm0,ymm14
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax]
vfmadd213ps ymm8,ymm0,ymm13
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15
vfmadd213ps ymm9,ymm1,ymm15
vfmadd213ps ymm10,ymm2,ymm15
vfmadd213ps ymm11,ymm3,ymm15
vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax]
vfmadd213ps ymm8,ymm0,ymm14
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
vfmadd213ps ymm8,ymm0,ymm13
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax]
; Final big-range step: p = p*va + va.
vfmadd213ps ymm8,ymm0,ymm0
vfmadd213ps ymm9,ymm1,ymm1
vfmadd213ps ymm10,ymm2,ymm2
vfmadd213ps ymm11,ymm3,ymm3
vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax]
; Flip the sign bit: the exponential is taken of the negated polynomial.
vxorps ymm8,ymm8,ymm15
vxorps ymm9,ymm9,ymm15
vxorps ymm10,ymm10,ymm15
vxorps ymm11,ymm11,ymm15
vbroadcastss ymm13,ErfConstants.Exp_C[rax]
vmovaps ymm5,ymm4
vmovaps ymm6,ymm4
vmovaps ymm7,ymm4
; expf(ymm8 -- ymm11)
; Clamp the exp argument from below with Exp_LowerRange.
vmaxps ymm8,ymm8,ymm14
vmaxps ymm9,ymm9,ymm14
vmaxps ymm10,ymm10,ymm14
vmaxps ymm11,ymm11,ymm14
vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax]
; ymm4-7 = ve * Exp_Log2Reciprocal + Exp_C; subtracting Exp_C again below
; leaves the rounded value vr.
vfmadd213ps ymm4,ymm8,ymm13
vfmadd213ps ymm5,ymm9,ymm13
vfmadd213ps ymm6,ymm10,ymm13
vfmadd213ps ymm7,ymm11,ymm13
vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax]
vmovaps ymm1,ymm0
vmovaps ymm2,ymm0
vmovaps ymm3,ymm0
vsubps ymm4,ymm4,ymm13 ; vr = round()
vsubps ymm5,ymm5,ymm13
vsubps ymm6,ymm6,ymm13
vsubps ymm7,ymm7,ymm13
vfmadd213ps ymm0,ymm4,ymm8 ; vf = vr * log2_hi + ve
vfmadd213ps ymm1,ymm5,ymm9
vfmadd213ps ymm2,ymm6,ymm10
vfmadd213ps ymm3,ymm7,ymm11
vbroadcastss ymm8,ErfConstants.Exp_P0[rax]
vfmadd231ps ymm0,ymm4,ymm15 ; vf += vr * log_2_lo
vfmadd231ps ymm1,ymm5,ymm15
vfmadd231ps ymm2,ymm6,ymm15
vfmadd231ps ymm3,ymm7,ymm15
vmovaps ymm9,ymm8
vmovaps ymm10,ymm8
vmovaps ymm11,ymm8
; Exponential polynomial, Horner form in vf.
vbroadcastss ymm14,ErfConstants.Exp_P1[rax]
vbroadcastss ymm13,ErfConstants.Exp_P2[rax]
vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p1
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm12,ErfConstants.Exp_P3[rax]
vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p2
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
vbroadcastss ymm15,ErfConstants.Exp_P4[rax]
vfmadd213ps ymm8,ymm0,ymm12 ; *+ exp_p3
vfmadd213ps ymm9,ymm1,ymm12
vfmadd213ps ymm10,ymm2,ymm12
vfmadd213ps ymm11,ymm3,ymm12
vbroadcastss ymm14,ErfConstants.Exp_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15 ; *+ exp_p4
vfmadd213ps ymm9,ymm1,ymm15
vfmadd213ps ymm10,ymm2,ymm15
vfmadd213ps ymm11,ymm3,ymm15
vbroadcastss ymm13,ErfConstants.Exp_P6[rax]
vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p5
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm12,ErfConstants.Exp_X7F[rax]
vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p6
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
; Build the scale factor from vr: convert to integer, add Exp_X7F and shift
; the sum left by 23 bits.
vcvttps2dq ymm4,ymm4
vcvttps2dq ymm5,ymm5
vcvttps2dq ymm6,ymm6
vcvttps2dq ymm7,ymm7
vbroadcastss ymm15,ErfConstants.ErfOne[rax]
vpaddd ymm4,ymm4,ymm12 ; +127
vpaddd ymm5,ymm5,ymm12
vpaddd ymm6,ymm6,ymm12
vpaddd ymm7,ymm7,ymm12
vpslld ymm4,ymm4,23
vpslld ymm5,ymm5,23
vpslld ymm6,ymm6,23
vpslld ymm7,ymm7,23
vmulps ymm8,ymm8,ymm4 ; 2^i * exp(vf)
vmulps ymm9,ymm9,ymm5
vmulps ymm10,ymm10,ymm6
vmulps ymm11,ymm11,ymm7
; Big-range result = ErfOne - exp(...).
vsubps ymm8,ymm15,ymm8
vsubps ymm9,ymm15,ymm9
vsubps ymm10,ymm15,ymm10
vsubps ymm11,ymm15,ymm11
; merge small numbers' result
vorps ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp]
vorps ymm9,ymm9,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32]
vorps ymm10,ymm10,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64]
vorps ymm11,ymm11,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96]
; copy sign
vorps ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp]
vorps ymm1,ymm9,YMMWORD PTR 32+ErfKernelFrame.ErfBuffer0[rsp]
vorps ymm2,ymm10,YMMWORD PTR 64+ErfKernelFrame.ErfBuffer0[rsp]
vorps ymm3,ymm11,YMMWORD PTR 96+ErfKernelFrame.ErfBuffer0[rsp]
vmovups YMMWORD PTR [rdx],ymm0
vmovups YMMWORD PTR [rdx+32],ymm1
vmovups YMMWORD PTR [rdx+64],ymm2
vmovups YMMWORD PTR [rdx+96],ymm3
add rcx,32*4 ; advance by 4*8 elements
add rdx,32*4
sub r8,32
jae LComputeErf4x8Loop
;
; Tail: one YMM vector (up to 8 floats) per iteration using masked load/store.
;
LErfProcessRemainingCount:
add r8,32 ; correct for over-subtract above
jz LErfBatchExp
LErfProcess1x8:
; Broadcast the remaining count and compare it against the MlasMaskMoveAvx
; table; a lane is enabled when count > table entry.
; NOTE(review): the table is defined elsewhere - presumably lane indices 0..7.
mov DWORD PTR ErfKernelFrame.CountN[rsp],r8d
vbroadcastss ymm3,DWORD PTR ErfKernelFrame.CountN[rsp]
vpcmpgtd ymm3,ymm3,YMMWORD PTR [MlasMaskMoveAvx]
vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx] ; original input vx0
vandps ymm4,ymm0,ymm15 ; vsign0
vandnps ymm0,ymm15,ymm0 ; abs(vx0) va0
vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax]
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4
vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax]
vminps ymm0,ymm0,ymm14 ; force abs value in range
vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax]
vmulps ymm4,ymm0,ymm0 ; vs0 (square)
; Small-range polynomial, same sequence as the main loop on a single vector.
vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax]
vfmadd213ps ymm8,ymm4,ymm15
vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax]
vfmadd213ps ymm8,ymm4,ymm14
vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax]
vfmadd213ps ymm8,ymm4,ymm13
vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax]
vfmadd213ps ymm8,ymm4,ymm15
vfmadd213ps ymm8,ymm4,ymm14
vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax]
vfmadd213ps ymm8,ymm0,ymm0
vcmpgtps ymm4,ymm0,ymm12 ; vmask0
vandnps ymm8,ymm4,ymm8
vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8
; Reached by fall-through only; nothing in this routine jumps to the label.
LBiggerNumbersRemaining:
vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax]
vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax]
vandps ymm0,ymm4,ymm0
; Big-range polynomial, same sequence as the main loop on a single vector.
vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax]
vfmadd213ps ymm8,ymm0,ymm15
vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax]
vfmadd213ps ymm8,ymm0,ymm14
vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax]
vfmadd213ps ymm8,ymm0,ymm13
vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15
vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax]
vfmadd213ps ymm8,ymm0,ymm14
vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
vfmadd213ps ymm8,ymm0,ymm13
vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax]
vfmadd213ps ymm8,ymm0,ymm0
vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax]
vxorps ymm8,ymm8,ymm15
vbroadcastss ymm13,ErfConstants.Exp_C[rax]
; expf(ymm8) - only a single vector is processed on this path
vmaxps ymm8,ymm8,ymm14
vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax]
vfmadd213ps ymm4,ymm8,ymm13
vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax]
vsubps ymm4,ymm4,ymm13 ; vr = round()
vfmadd213ps ymm0,ymm4,ymm8 ; vf = vr * log2_hi + ve
vbroadcastss ymm8,ErfConstants.Exp_P0[rax]
vfmadd231ps ymm0,ymm4,ymm15 ; vf += vr * log_2_lo
vbroadcastss ymm14,ErfConstants.Exp_P1[rax]
vbroadcastss ymm13,ErfConstants.Exp_P2[rax]
vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p1
vbroadcastss ymm12,ErfConstants.Exp_P3[rax]
vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p2
vbroadcastss ymm15,ErfConstants.Exp_P4[rax]
vfmadd213ps ymm8,ymm0,ymm12 ; *+ exp_p3
vbroadcastss ymm14,ErfConstants.Exp_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15 ; *+ exp_p4
vbroadcastss ymm13,ErfConstants.Exp_P6[rax]
vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p5
vbroadcastss ymm12,ErfConstants.Exp_X7F[rax]
vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p6
vcvttps2dq ymm4,ymm4
vbroadcastss ymm15,ErfConstants.ErfOne[rax]
vpaddd ymm4,ymm4,ymm12 ; +127
vpslld ymm4,ymm4,23
vmulps ymm8,ymm8,ymm4 ; 2^i * exp(vf)
vsubps ymm8,ymm15,ymm8
; merge small numbers' result
vorps ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp]
; copy sign
vorps ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp]
; Store only the lanes enabled in the ymm3 mask.
vmaskmovps YMMWORD PTR [rdx],ymm3,ymm0
add rcx,8*4
add rdx,8*4
; Continue while more than 8 elements remained before this iteration.
sub r8,8
jg LErfProcess1x8
;
; Common exit: restore xmm6-xmm15 from the frame and release the frame.
;
LErfBatchExp:
vzeroupper
vmovaps xmm6,ErfKernelFrame.SavedXmm6[rsp]
vmovaps xmm7,ErfKernelFrame.SavedXmm7[rsp]
vmovaps xmm8,ErfKernelFrame.SavedXmm8[rsp]
vmovaps xmm9,ErfKernelFrame.SavedXmm9[rsp]
vmovaps xmm10,ErfKernelFrame.SavedXmm10[rsp]
vmovaps xmm11,ErfKernelFrame.SavedXmm11[rsp]
vmovaps xmm12,ErfKernelFrame.SavedXmm12[rsp]
vmovaps xmm13,ErfKernelFrame.SavedXmm13[rsp]
vmovaps xmm14,ErfKernelFrame.SavedXmm14[rsp]
vmovaps xmm15,ErfKernelFrame.SavedXmm15[rsp]
add rsp,(ErfKernelFrame.ReturnAddress)
BEGIN_EPILOGUE
ret
NESTED_END MlasErfKernelFma3, _TEXT
END

View file

@ -0,0 +1,271 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
erf.cpp
Abstract:
This module implements routines to compute the error function.
This implementation uses the same polynomial coefficients and algorithm as
found in: https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff
Our usage requires building platform specific versions of
the algorithm to target different instruction sets. The implementation below
targets the base instruction set (typically SSE2) while assembly
implementations target newer instruction sets (such as FMA3).
--*/
#include "mlasi.h"
#include <cmath>
//
// Bundles the constants for use by kernels written in assembly.
//
//
// Each member is four bytes wide and the assembly kernels address the members
// by fixed offset, so the declaration order below must not change. The
// initializer list supplies the values in the same order.
//
extern "C" const struct {
float ErfUpperAbsRange; // abs(x) is clamped to this value
float ErfSplitBoundary; // abs(x) above this uses the "big" path
float ErfSMALL_P0; // small-range polynomial coefficients (in x*x)
float ErfSMALL_P1;
float ErfSMALL_P2;
float ErfSMALL_P3;
float ErfSMALL_P4;
float ErfSMALL_P5_Minus_One; // last step is r * x + x, hence "minus one"
float ErfReserved0; // not read by the kernels in this file
float ErfBIG_P0; // big-range polynomial coefficients (in abs(x))
float ErfBIG_P1;
float ErfBIG_P2;
float ErfBIG_P3;
float ErfBIG_P4;
float ErfBIG_P5;
float ErfBIG_P6_Minus_One; // last step is r * a + a, hence "minus one"
float ErfNegZero; // sign-bit mask
float ErfOne;
float Exp_UpperRange; // not read by the kernels in this file
float Exp_LowerRange; // lower clamp for the exponential argument
float Exp_Log2Reciprocal;
float Exp_log2_hi;
float Exp_log2_lo;
float Exp_P0; // exponential polynomial coefficients
float Exp_P1;
float Exp_P2;
float Exp_P3;
float Exp_P4;
float Exp_P5;
float Exp_P6;
float Exp_C; // added then subtracted to round to an integral value
int32_t Exp_X7F; // integer added to the exponent by the assembly kernels
} MlasErfConstants = {
3.725f, // ErfUpperAbsRange
0.921875f, // ErfSplitBoundary
-5.99104969e-4f, // ErfSMALL_P0
4.99339588e-3f, // ErfSMALL_P1
-2.67667342e-2f, // ErfSMALL_P2
1.12818025e-1f, // ErfSMALL_P3
-3.76124859e-1f, // ErfSMALL_P4
1.28379151e-1f, // ErfSMALL_P5_Minus_One
0.0f, // ErfReserved0
1.72948930e-5f, // ErfBIG_P0
-3.83208680e-4f, // ErfBIG_P1
3.88393435e-3f, // ErfBIG_P2
-2.42545605e-2f, // ErfBIG_P3
1.06777847e-1f, // ErfBIG_P4
6.34846687e-1f, // ErfBIG_P5
1.28717512e-1f, // ErfBIG_P6_Minus_One
-0.0f, // ErfNegZero
1.0f, // ErfOne
// Independent parameters to calculate Exp for Erff()
88.3762626647950f, // Exp_UpperRange
-88.3762626647949f, // Exp_LowerRange
1.44269504088896341f, // Exp_Log2Reciprocal
-6.93145752e-1f, // Exp_log2_hi
-1.42860677e-6f, // Exp_log2_lo
1.38319808e-3f, // Exp_P0
8.37550033e-3f, // Exp_P1
4.16689515e-2f, // Exp_P2
1.66664466e-1f, // Exp_P3
4.99999851e-1f, // Exp_P4
1.00000000e+0f, // Exp_P5
1.00000000e+0f, // Exp_P6
1.25829120e+7f, // Exp_C
127, // Exp_X7F
};
void
MLASCALL
MlasErfKernel(
    const float* Input,
    float* Output,
    size_t N
    )
/*++

Routine Description:

    Generic (non-assembly) erf kernel. Elements are processed four at a time
    with the MLAS_FLOAT32X4 helpers; any remaining elements are handled one
    at a time with scalar arithmetic.

    Both paths evaluate the same approximation: a polynomial in x*x for
    magnitudes up to ErfSplitBoundary, and 1 - exp(-p(|x|)) above it.

Arguments:

    Input - Supplies the input buffer.

    Output - Supplies the output buffer.

    N - Supplies the number of elements to process.

Return Value:

    None.

--*/
{
    //
    // Splat the constants used by the vector loop.
    //

    const MLAS_FLOAT32X4 SignBit = MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero);
    const MLAS_FLOAT32X4 One = MlasBroadcastFloat32x4(MlasErfConstants.ErfOne);
    const MLAS_FLOAT32X4 UpperAbsRange = MlasBroadcastFloat32x4(MlasErfConstants.ErfUpperAbsRange);
    const MLAS_FLOAT32X4 SplitBoundary = MlasBroadcastFloat32x4(MlasErfConstants.ErfSplitBoundary);

    const MLAS_FLOAT32X4 SmallP0 = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P0);
    const MLAS_FLOAT32X4 SmallP1 = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P1);
    const MLAS_FLOAT32X4 SmallP2 = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P2);
    const MLAS_FLOAT32X4 SmallP3 = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P3);
    const MLAS_FLOAT32X4 SmallP4 = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P4);
    const MLAS_FLOAT32X4 SmallP5MinusOne = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P5_Minus_One);

    const MLAS_FLOAT32X4 BigP0 = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P0);
    const MLAS_FLOAT32X4 BigP1 = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P1);
    const MLAS_FLOAT32X4 BigP2 = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P2);
    const MLAS_FLOAT32X4 BigP3 = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P3);
    const MLAS_FLOAT32X4 BigP4 = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P4);
    const MLAS_FLOAT32X4 BigP5 = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P5);
    const MLAS_FLOAT32X4 BigP6MinusOne = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P6_Minus_One);

    const MLAS_FLOAT32X4 ExpLowerRange = MlasBroadcastFloat32x4(MlasErfConstants.Exp_LowerRange);
    const MLAS_FLOAT32X4 ExpLog2Reciprocal = MlasBroadcastFloat32x4(MlasErfConstants.Exp_Log2Reciprocal);
    const MLAS_FLOAT32X4 ExpLog2Hi = MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_hi);
    const MLAS_FLOAT32X4 ExpLog2Lo = MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_lo);
    const MLAS_FLOAT32X4 ExpC = MlasBroadcastFloat32x4(MlasErfConstants.Exp_C);
    const MLAS_FLOAT32X4 ExpP0 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P0);
    const MLAS_FLOAT32X4 ExpP1 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P1);
    const MLAS_FLOAT32X4 ExpP2 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P2);
    const MLAS_FLOAT32X4 ExpP3 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P3);
    const MLAS_FLOAT32X4 ExpP4 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P4);
    const MLAS_FLOAT32X4 ExpP5 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P5);
    const MLAS_FLOAT32X4 ExpP6 = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P6);

    //
    // Vector path: four elements per iteration. Both ranges are evaluated
    // for every lane and combined with bit masks instead of branching.
    //

    for (; N >= 4; N -= 4, Input += 4, Output += 4) {

        const MLAS_FLOAT32X4 X = MlasLoadFloat32x4(Input);
        const MLAS_FLOAT32X4 Sign = MlasAndFloat32x4(X, SignBit);
        const MLAS_FLOAT32X4 Magnitude =
            MlasMinimumFloat32x4(UpperAbsRange, MlasAndNotFloat32x4(SignBit, X));
        const MLAS_FLOAT32X4 Square = MlasMultiplyFloat32x4(Magnitude, Magnitude);

        // Small range: Horner evaluation in x*x, then p * |x| + |x|.
        MLAS_FLOAT32X4 Small = SmallP0;
        Small = MlasMultiplyAddFloat32x4(Small, Square, SmallP1);
        Small = MlasMultiplyAddFloat32x4(Small, Square, SmallP2);
        Small = MlasMultiplyAddFloat32x4(Small, Square, SmallP3);
        Small = MlasMultiplyAddFloat32x4(Small, Square, SmallP4);
        Small = MlasMultiplyAddFloat32x4(Small, Square, SmallP5MinusOne);
        Small = MlasMultiplyAddFloat32x4(Small, Magnitude, Magnitude);

        // Keep the small result only where |x| <= boundary, and feed the big
        // path a zero in exactly those lanes.
        const MLAS_FLOAT32X4 IsBig = MlasGreaterThanFloat32x4(Magnitude, SplitBoundary);
        Small = MlasAndNotFloat32x4(IsBig, Small);
        const MLAS_FLOAT32X4 BigInput = MlasAndFloat32x4(IsBig, Magnitude);

        // Big range: Horner evaluation in |x|, then p * |x| + |x|.
        MLAS_FLOAT32X4 ExpArg = BigP0;
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigP1);
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigP2);
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigP3);
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigP4);
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigP5);
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigP6MinusOne);
        ExpArg = MlasMultiplyAddFloat32x4(ExpArg, BigInput, BigInput);

        // Negate and clamp from below; the argument is never positive, so
        // only the lower bound is applied.
        ExpArg = MlasXorFloat32x4(ExpArg, SignBit);
        ExpArg = MlasMaximumFloat32x4(ExpLowerRange, ExpArg);

        // Range reduction: Rounded is arg * (1/ln2) rounded via the add and
        // subtract of Exp_C; Reduced is the remainder fed to the polynomial.
        MLAS_FLOAT32X4 Rounded = MlasMultiplyAddFloat32x4(ExpLog2Reciprocal, ExpArg, ExpC);
        Rounded = MlasSubtractFloat32x4(Rounded, ExpC);
        MLAS_FLOAT32X4 Reduced = MlasMultiplyAddFloat32x4(Rounded, ExpLog2Hi, ExpArg);
        Reduced = MlasMultiplyAddFloat32x4(Rounded, ExpLog2Lo, Reduced);

        // Polynomial approximation of exp(Reduced).
        MLAS_FLOAT32X4 Result = ExpP0;
        Result = MlasMultiplyAddFloat32x4(Result, Reduced, ExpP1);
        Result = MlasMultiplyAddFloat32x4(Result, Reduced, ExpP2);
        Result = MlasMultiplyAddFloat32x4(Result, Reduced, ExpP3);
        Result = MlasMultiplyAddFloat32x4(Result, Reduced, ExpP4);
        Result = MlasMultiplyAddFloat32x4(Result, Reduced, ExpP5);
        Result = MlasMultiplyAddFloat32x4(Result, Reduced, ExpP6);

        // 1 - exp(Reduced) * 2^Rounded, then merge the small lanes and
        // reattach the original sign.
        Result = MlasMultiplyFloat32x4(Result, MlasPowerOf2Float32x4(Rounded));
        Result = MlasSubtractFloat32x4(One, Result);
        Result = MlasOrFloat32x4(Small, Result);
        Result = MlasOrFloat32x4(Result, Sign);

        MlasStoreFloat32x4(Output, Result);
    }

    //
    // Scalar path for the remaining (fewer than four) elements.
    //

    for (; N > 0; N -= 1) {

        const float X = *Input++;
        float Magnitude = fabsf(X);
        float Result;

        if (!(Magnitude > MlasErfConstants.ErfSplitBoundary)) {

            // Small range (also taken when the comparison is unordered).
            const float Square = Magnitude * Magnitude;

            Result = MlasErfConstants.ErfSMALL_P0;
            Result = Result * Square + MlasErfConstants.ErfSMALL_P1;
            Result = Result * Square + MlasErfConstants.ErfSMALL_P2;
            Result = Result * Square + MlasErfConstants.ErfSMALL_P3;
            Result = Result * Square + MlasErfConstants.ErfSMALL_P4;
            Result = Result * Square + MlasErfConstants.ErfSMALL_P5_Minus_One;
            Result = Result * X + X;

        } else {

            // Big range: 1 - exp(-p(|x|)) with |x| clamped from above.
            Magnitude = (std::min)(MlasErfConstants.ErfUpperAbsRange, Magnitude);

            float ExpArg = MlasErfConstants.ErfBIG_P0;
            ExpArg = ExpArg * Magnitude + MlasErfConstants.ErfBIG_P1;
            ExpArg = ExpArg * Magnitude + MlasErfConstants.ErfBIG_P2;
            ExpArg = ExpArg * Magnitude + MlasErfConstants.ErfBIG_P3;
            ExpArg = ExpArg * Magnitude + MlasErfConstants.ErfBIG_P4;
            ExpArg = ExpArg * Magnitude + MlasErfConstants.ErfBIG_P5;
            ExpArg = ExpArg * Magnitude + MlasErfConstants.ErfBIG_P6_Minus_One;
            ExpArg = ExpArg * Magnitude + Magnitude;
            ExpArg = (std::max)(-ExpArg, MlasErfConstants.Exp_LowerRange);

            float Rounded = MlasErfConstants.Exp_Log2Reciprocal * ExpArg + MlasErfConstants.Exp_C;
            Rounded -= MlasErfConstants.Exp_C;

            float Reduced = Rounded * MlasErfConstants.Exp_log2_hi + ExpArg;
            Reduced = Rounded * MlasErfConstants.Exp_log2_lo + Reduced;

            float Exponential = MlasErfConstants.Exp_P0;
            Exponential = Exponential * Reduced + MlasErfConstants.Exp_P1;
            Exponential = Exponential * Reduced + MlasErfConstants.Exp_P2;
            Exponential = Exponential * Reduced + MlasErfConstants.Exp_P3;
            Exponential = Exponential * Reduced + MlasErfConstants.Exp_P4;
            Exponential = Exponential * Reduced + MlasErfConstants.Exp_P5;
            Exponential = Exponential * Reduced + MlasErfConstants.Exp_P6;

            Result = 1.0f - ldexpf(Exponential, static_cast<int>(Rounded));

            // The magnitude was used above; restore the sign of the input.
            if (X <= -0.0f) {
                Result = -Result;
            }
        }

        *Output++ = Result;
    }
}
void
MLASCALL
MlasComputeErf(
    const float* Input,
    float* Output,
    size_t N
    )
/*++

Routine Description:

    Public entry point for the error function. The work is forwarded to the
    generic kernel, except on AMD64 builds where the kernel registered in the
    platform block is invoked instead.

Arguments:

    Input - Supplies the input buffer.

    Output - Supplies the output buffer.

    N - Supplies the number of elements to process.

Return Value:

    None.

--*/
{
#if !defined(MLAS_TARGET_AMD64)
    MlasErfKernel(Input, Output, N);
#else
    MlasPlatform.ErfKernelRoutine(Input, Output, N);
#endif
}

View file

@ -174,6 +174,16 @@ void
typedef MLAS_TANH_KERNEL_ROUTINE* PMLAS_TANH_KERNEL_ROUTINE;
//
// Signature shared by all erf kernels: read N floats from Input and write N
// results to Output. PMLAS_ERF_KERNEL_ROUTINE is the matching pointer type.
//
typedef
void
(MLASCALL MLAS_ERF_KERNEL_ROUTINE)(
const float* Input,
float* Output,
size_t N
);
typedef MLAS_ERF_KERNEL_ROUTINE* PMLAS_ERF_KERNEL_ROUTINE;
extern "C" {
MLAS_SGEMM_KERNEL_ROUTINE MlasSgemmKernelZero;
@ -203,9 +213,11 @@ extern "C" {
MLAS_TANH_KERNEL_ROUTINE MlasLogisticKernel;
MLAS_TANH_KERNEL_ROUTINE MlasTanhKernel;
MLAS_ERF_KERNEL_ROUTINE MlasErfKernel;
#if defined(MLAS_TARGET_AMD64)
MLAS_TANH_KERNEL_ROUTINE MlasLogisticKernelFma3;
MLAS_TANH_KERNEL_ROUTINE MlasTanhKernelFma3;
MLAS_ERF_KERNEL_ROUTINE MlasErfKernelFma3;
#endif
}
@ -269,6 +281,7 @@ struct MLAS_PLATFORM {
PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE TransposePackB16x4Routine;
PMLAS_LOGISTIC_KERNEL_ROUTINE LogisticKernelRoutine;
PMLAS_TANH_KERNEL_ROUTINE TanhKernelRoutine;
PMLAS_ERF_KERNEL_ROUTINE ErfKernelRoutine;
#endif
#if defined(MLAS_USE_WIN32_THREADPOOL)
@ -574,6 +587,75 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
#endif
}
//
// Lane-wise Vector1 > Vector2. The comparison result is returned as a bit
// mask reinterpreted as a float vector, suitable for the And/AndNot helpers.
//
inline
MLAS_FLOAT32X4
MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
{
#if defined(MLAS_NEON_INTRINSICS)
    const uint32x4_t Mask = vcgtq_f32(Vector1, Vector2);
    return vreinterpretq_f32_u32(Mask);
#elif defined(MLAS_SSE2_INTRINSICS)
    const MLAS_FLOAT32X4 Mask = _mm_cmpgt_ps(Vector1, Vector2);
    return Mask;
#endif
}
//
// Bitwise AND of the raw bits of two float vectors.
//
inline
MLAS_FLOAT32X4
MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
{
#if defined(MLAS_NEON_INTRINSICS)
    const uint32x4_t Bits1 = vreinterpretq_u32_f32(Vector1);
    const uint32x4_t Bits2 = vreinterpretq_u32_f32(Vector2);
    return vreinterpretq_f32_u32(vandq_u32(Bits1, Bits2));
#elif defined(MLAS_SSE2_INTRINSICS)
    const MLAS_FLOAT32X4 Bits = _mm_and_ps(Vector1, Vector2);
    return Bits;
#endif
}
//
// Bitwise OR of the raw bits of two float vectors.
//
inline
MLAS_FLOAT32X4
MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
{
#if defined(MLAS_NEON_INTRINSICS)
    const uint32x4_t Bits1 = vreinterpretq_u32_f32(Vector1);
    const uint32x4_t Bits2 = vreinterpretq_u32_f32(Vector2);
    return vreinterpretq_f32_u32(vorrq_u32(Bits1, Bits2));
#elif defined(MLAS_SSE2_INTRINSICS)
    const MLAS_FLOAT32X4 Bits = _mm_or_ps(Vector1, Vector2);
    return Bits;
#endif
}
//
// Computes (~VectorNot) & Vector on the raw bits: the first argument is the
// one that gets inverted.
//
inline
MLAS_FLOAT32X4
MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector)
{
#if defined(MLAS_NEON_INTRINSICS)
    const uint32x4_t Inverted = vmvnq_u32(vreinterpretq_u32_f32(VectorNot));
    const uint32x4_t Bits = vreinterpretq_u32_f32(Vector);
    return vreinterpretq_f32_u32(vandq_u32(Inverted, Bits));
#elif defined(MLAS_SSE2_INTRINSICS)
    const MLAS_FLOAT32X4 Bits = _mm_andnot_ps(VectorNot, Vector);
    return Bits;
#endif
}
//
// Bitwise XOR of the raw bits of two float vectors.
//
inline
MLAS_FLOAT32X4
MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
{
#if defined(MLAS_NEON_INTRINSICS)
    const uint32x4_t Bits1 = vreinterpretq_u32_f32(Vector1);
    const uint32x4_t Bits2 = vreinterpretq_u32_f32(Vector2);
    return vreinterpretq_f32_u32(veorq_u32(Bits1, Bits2));
#elif defined(MLAS_SSE2_INTRINSICS)
    const MLAS_FLOAT32X4 Bits = _mm_xor_ps(Vector1, Vector2);
    return Bits;
#endif
}
//
// Builds 2^int(Vector) per lane: each lane is converted to a 32-bit integer,
// 0x7f is added, and the sum is shifted left by 23 bits before the bits are
// reinterpreted as a float. No range checking is performed here.
//
inline
MLAS_FLOAT32X4
MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector)
{
#if defined(MLAS_NEON_INTRINSICS)
    const int32x4_t Exponent = vcvtq_s32_f32(Vector);
    const int32x4_t Biased = vaddq_s32(Exponent, vdupq_n_s32(0x7f));
    return vreinterpretq_f32_s32(vshlq_n_s32(Biased, 23));
#elif defined(MLAS_SSE2_INTRINSICS)
    const __m128i Exponent = _mm_cvttps_epi32(Vector);
    const __m128i Biased = _mm_add_epi32(Exponent, _mm_set1_epi32(0x7f));
    return _mm_castsi128_ps(_mm_slli_epi32(Biased, 23));
#endif
}
//
// Reads a platform specific time stamp counter.
//

View file

@ -90,6 +90,7 @@ Return Value:
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
this->LogisticKernelRoutine = MlasLogisticKernel;
this->TanhKernelRoutine = MlasTanhKernel;
this->ErfKernelRoutine = MlasErfKernel;
#endif
//
@ -144,6 +145,7 @@ Return Value:
this->LogisticKernelRoutine = MlasLogisticKernelFma3;
this->TanhKernelRoutine = MlasTanhKernelFma3;
this->ErfKernelRoutine = MlasErfKernelFma3;
} else {

View file

@ -0,0 +1,517 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
ErfKernelFma3.S
Abstract:
This module implements a kernel for computing the error function for a
buffer of elements.
This implementation uses AVX fused multiply/add instructions.
--*/
#include "asmmacro.h"
.intel_syntax noprefix
.text
//
// Structure layout for the erf constants block.
//
// Byte offsets of the 4-byte members of MlasErfConstants; the kernel indexes
// them off rax, which holds the address of that block.
.equ ErfUpperAbsRange, 0
.equ ErfSplitBoundary, 4
// Small-range polynomial coefficients.
.equ ErfSMALL_P0, 8
.equ ErfSMALL_P1, 12
.equ ErfSMALL_P2, 16
.equ ErfSMALL_P3, 20
.equ ErfSMALL_P4, 24
.equ ErfSMALL_P5_Minus_One, 28
.equ ErfReserve0, 32
// Big-range polynomial coefficients.
.equ ErfBIG_P0, 36
.equ ErfBIG_P1, 40
.equ ErfBIG_P2, 44
.equ ErfBIG_P3, 48
.equ ErfBIG_P4, 52
.equ ErfBIG_P5, 56
.equ ErfBIG_P6_Minus_One, 60
.equ ErfNegZero, 64
.equ ErfOne, 68
// The exponential constants follow directly after ErfOne; their offsets are
// expressed relative to ExpConstOffset.
.equ ExpConstOffset, 72
.equ Exp_UpperRange, 0 + ExpConstOffset
.equ Exp_LowerRange, 4 + ExpConstOffset
.equ Exp_Log2Reciprocal, 8 + ExpConstOffset
.equ Exp_log2_hi, 12 + ExpConstOffset
.equ Exp_log2_lo, 16 + ExpConstOffset
.equ Exp_P0, 20 + ExpConstOffset
.equ Exp_P1, 24 + ExpConstOffset
.equ Exp_P2, 28 + ExpConstOffset
.equ Exp_P3, 32 + ExpConstOffset
.equ Exp_P4, 36 + ExpConstOffset
.equ Exp_P5, 40 + ExpConstOffset
.equ Exp_P6, 44 + ExpConstOffset
.equ Exp_C, 48 + ExpConstOffset
.equ Exp_X7F, 52 + ExpConstOffset
//
// Stack frame layout for the erf kernel.
//
.equ ErfBuffer0, 0
.equ ErfBuffer1, 128
.equ ErfKernelFrame_CountN, 256
.equ ErfKernelFrame_ReturnAddress, 256+8
/*++
Routine Description:
This routine implements a vectorized kernel for the error function.

Elements are processed 32 at a time (four 8-float vectors) in the main
loop, then 8 at a time with masked loads/stores for the remainder.

For each vector the kernel splits the input into sign and absolute value,
clamps the absolute value to ErfUpperAbsRange and evaluates two
approximations:
  small range: a + a * P(a * a), kept where a <= ErfSplitBoundary
  big range:   ErfOne - exp(-(a + a * Q(a))), computed on lanes where
               a > ErfSplitBoundary (other lanes are fed zero)
The two results are OR-ed together and the saved sign bits are OR-ed back
in.

Register usage in the loops: ymm0-ymm11 hold data, ymm12-ymm15 hold
broadcast constants that are reloaded as they are needed.

Arguments:
Input (rdi) - Supplies the input buffer.
Output (rsi) - Supplies the output buffer.
N (rdx) - Supplies the number of elements to process.
Return Value:
None.
--*/
.globl C_UNDERSCORE(MlasErfKernelFma3)
C_UNDERSCORE(MlasErfKernelFma3):
// Reserve the two 128-byte scratch buffers and the CountN slot.
sub rsp,ErfKernelFrame_ReturnAddress
// rax = address of the constants block; all constant offsets index from it.
mov rax,C_UNDERSCORE(MlasErfConstants)@GOTPCREL[rip]
// Bias the count by 32 so the carry/borrow flag drives the main loop.
sub rdx,8*4
jb .LErfProcessRemainingCount
.LComputeErf4x8Loop:
//
// Load four vectors and split each into sign bits and absolute value.
//
vbroadcastss ymm15,ErfNegZero[rax]
vmovups ymm0,YMMWORD PTR [rdi] # original input vx0
vmovups ymm1,YMMWORD PTR [rdi+32] # original input vx1
vmovups ymm2,YMMWORD PTR [rdi+64] # original input vx2
vmovups ymm3,YMMWORD PTR [rdi+96] # original input vx3
vandps ymm4,ymm0,ymm15 # vsign0
vandps ymm5,ymm1,ymm15 # vsign1
vandps ymm6,ymm2,ymm15 # vsign2
vandps ymm7,ymm3,ymm15 # vsign3
vandnps ymm0,ymm15,ymm0 # abs(vx0) va0
vandnps ymm1,ymm15,ymm1 # abs(vx1) va1
vandnps ymm2,ymm15,ymm2 # abs(vx2) va2
vandnps ymm3,ymm15,ymm3 # abs(vx3) va3
vbroadcastss ymm14,ErfUpperAbsRange[rax]
// Park the sign bits in ErfBuffer0; ymm4-ymm7 are reused below.
vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4
vmovups YMMWORD PTR ErfBuffer0[rsp+32],ymm5
vmovups YMMWORD PTR ErfBuffer0[rsp+64],ymm6
vmovups YMMWORD PTR ErfBuffer0[rsp+96],ymm7
vbroadcastss ymm8,ErfSMALL_P0[rax]
vminps ymm0,ymm0,ymm14 # force abs value in range
vminps ymm1,ymm1,ymm14
vminps ymm2,ymm2,ymm14
vminps ymm3,ymm3,ymm14
//
// Small-range polynomial, evaluated by Horner's rule in vs = va * va.
// Each vfmadd213ps below computes acc = acc * vs + coefficient.
//
vmovaps ymm9,ymm8
vmovaps ymm10,ymm8
vmovaps ymm11,ymm8
vbroadcastss ymm15,ErfSMALL_P1[rax]
vmulps ymm4,ymm0,ymm0 # vs0 (square)
vmulps ymm5,ymm1,ymm1 # vs1
vmulps ymm6,ymm2,ymm2 # vs2
vmulps ymm7,ymm3,ymm3 # vs3
vbroadcastss ymm14,ErfSMALL_P2[rax]
vfmadd213ps ymm8,ymm4,ymm15
vfmadd213ps ymm9,ymm5,ymm15
vfmadd213ps ymm10,ymm6,ymm15
vfmadd213ps ymm11,ymm7,ymm15
vbroadcastss ymm13,ErfSMALL_P3[rax]
vfmadd213ps ymm8,ymm4,ymm14
vfmadd213ps ymm9,ymm5,ymm14
vfmadd213ps ymm10,ymm6,ymm14
vfmadd213ps ymm11,ymm7,ymm14
vbroadcastss ymm15,ErfSMALL_P4[rax]
vfmadd213ps ymm8,ymm4,ymm13
vfmadd213ps ymm9,ymm5,ymm13
vfmadd213ps ymm10,ymm6,ymm13
vfmadd213ps ymm11,ymm7,ymm13
vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax]
vfmadd213ps ymm8,ymm4,ymm15
vfmadd213ps ymm9,ymm5,ymm15
vfmadd213ps ymm10,ymm6,ymm15
vfmadd213ps ymm11,ymm7,ymm15
vfmadd213ps ymm8,ymm4,ymm14
vfmadd213ps ymm9,ymm5,ymm14
vfmadd213ps ymm10,ymm6,ymm14
vfmadd213ps ymm11,ymm7,ymm14
vbroadcastss ymm12,ErfSplitBoundary[rax]
// Final step of the small path: result = poly * va + va.
vfmadd213ps ymm8,ymm0,ymm0
vfmadd213ps ymm9,ymm1,ymm1
vfmadd213ps ymm10,ymm2,ymm2
vfmadd213ps ymm11,ymm3,ymm3
// vmask is all-ones in lanes where va > ErfSplitBoundary (big range).
vcmpgtps ymm4,ymm0,ymm12 # vmask0
vcmpgtps ymm5,ymm1,ymm12 # vmask1
vcmpgtps ymm6,ymm2,ymm12 # vmask2
vcmpgtps ymm7,ymm3,ymm12 # vmask3
// Keep the small-range result only where vmask is clear.
vandnps ymm8,ymm4,ymm8
vandnps ymm9,ymm5,ymm9
vandnps ymm10,ymm6,ymm10
vandnps ymm11,ymm7,ymm11
vbroadcastss ymm15,ErfBIG_P1[rax]
// Park the small-range results in ErfBuffer1 until the final merge.
vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8
vmovups YMMWORD PTR ErfBuffer1[rsp+32],ymm9
vmovups YMMWORD PTR ErfBuffer1[rsp+64],ymm10
vmovups YMMWORD PTR ErfBuffer1[rsp+96],ymm11
// This label is not a branch target in this routine; execution falls
// through with ymm15 already holding ErfBIG_P1.
.BiggerNumbers:
//
// Big-range polynomial, evaluated by Horner's rule in va. Lanes outside the
// big range are zeroed first so they feed zero into this path.
//
vbroadcastss ymm8,ErfBIG_P0[rax]
vandps ymm0,ymm4,ymm0
vandps ymm1,ymm5,ymm1
vandps ymm2,ymm6,ymm2
vandps ymm3,ymm7,ymm3
vmovaps ymm9,ymm8
vmovaps ymm10,ymm8
vmovaps ymm11,ymm8
vbroadcastss ymm14,ErfBIG_P2[rax]
vfmadd213ps ymm8,ymm0,ymm15
vfmadd213ps ymm9,ymm1,ymm15
vfmadd213ps ymm10,ymm2,ymm15
vfmadd213ps ymm11,ymm3,ymm15
vbroadcastss ymm13,ErfBIG_P3[rax]
vfmadd213ps ymm8,ymm0,ymm14
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm15,ErfBIG_P4[rax]
vfmadd213ps ymm8,ymm0,ymm13
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
vbroadcastss ymm14,ErfBIG_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15
vfmadd213ps ymm9,ymm1,ymm15
vfmadd213ps ymm10,ymm2,ymm15
vfmadd213ps ymm11,ymm3,ymm15
vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax]
vfmadd213ps ymm8,ymm0,ymm14
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm15,ErfNegZero[rax]
vfmadd213ps ymm8,ymm0,ymm13
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
vbroadcastss ymm14,Exp_LowerRange[rax]
// poly * va + va, as in the small path.
vfmadd213ps ymm8,ymm0,ymm0
vfmadd213ps ymm9,ymm1,ymm1
vfmadd213ps ymm10,ymm2,ymm2
vfmadd213ps ymm11,ymm3,ymm3
vbroadcastss ymm4,Exp_Log2Reciprocal[rax]
// Flip the sign bit: the exponent argument is the negated polynomial value.
vxorps ymm8,ymm8,ymm15
vxorps ymm9,ymm9,ymm15
vxorps ymm10,ymm10,ymm15
vxorps ymm11,ymm11,ymm15
vbroadcastss ymm13,Exp_C[rax]
vmovaps ymm5,ymm4
vmovaps ymm6,ymm4
vmovaps ymm7,ymm4
# expf(ymm8 -- ymm11)
// Clamp the exponent argument (ve) from below to Exp_LowerRange.
vmaxps ymm8,ymm8,ymm14
vmaxps ymm9,ymm9,ymm14
vmaxps ymm10,ymm10,ymm14
vmaxps ymm11,ymm11,ymm14
vbroadcastss ymm0,Exp_log2_hi[rax]
// ve * Exp_Log2Reciprocal + Exp_C; Exp_C is subtracted again below.
vfmadd213ps ymm4,ymm8,ymm13
vfmadd213ps ymm5,ymm9,ymm13
vfmadd213ps ymm6,ymm10,ymm13
vfmadd213ps ymm7,ymm11,ymm13
vbroadcastss ymm15,Exp_log2_lo[rax]
vmovaps ymm1,ymm0
vmovaps ymm2,ymm0
vmovaps ymm3,ymm0
vsubps ymm4,ymm4,ymm13 # vr = round()
vsubps ymm5,ymm5,ymm13
vsubps ymm6,ymm6,ymm13
vsubps ymm7,ymm7,ymm13
vfmadd213ps ymm0,ymm4,ymm8 # vf = vr * log2_hi + ve
vfmadd213ps ymm1,ymm5,ymm9
vfmadd213ps ymm2,ymm6,ymm10
vfmadd213ps ymm3,ymm7,ymm11
vbroadcastss ymm8,Exp_P0[rax]
vfmadd231ps ymm0,ymm4,ymm15 # vf += vr * log_2_lo
vfmadd231ps ymm1,ymm5,ymm15
vfmadd231ps ymm2,ymm6,ymm15
vfmadd231ps ymm3,ymm7,ymm15
// exp polynomial in vf, Horner's rule from Exp_P0 through Exp_P6.
vmovaps ymm9,ymm8
vmovaps ymm10,ymm8
vmovaps ymm11,ymm8
vbroadcastss ymm14,Exp_P1[rax]
vbroadcastss ymm13,Exp_P2[rax]
vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p1
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm12,Exp_P3[rax]
vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p2
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
vbroadcastss ymm15,Exp_P4[rax]
vfmadd213ps ymm8,ymm0,ymm12 # *+ exp_p3
vfmadd213ps ymm9,ymm1,ymm12
vfmadd213ps ymm10,ymm2,ymm12
vfmadd213ps ymm11,ymm3,ymm12
vbroadcastss ymm14,Exp_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15 # *+ exp_p4
vfmadd213ps ymm9,ymm1,ymm15
vfmadd213ps ymm10,ymm2,ymm15
vfmadd213ps ymm11,ymm3,ymm15
vbroadcastss ymm13,Exp_P6[rax]
vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p5
vfmadd213ps ymm9,ymm1,ymm14
vfmadd213ps ymm10,ymm2,ymm14
vfmadd213ps ymm11,ymm3,ymm14
vbroadcastss ymm12,Exp_X7F[rax]
vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p6
vfmadd213ps ymm9,ymm1,ymm13
vfmadd213ps ymm10,ymm2,ymm13
vfmadd213ps ymm11,ymm3,ymm13
// Build the power-of-two scale: convert vr to integer, add the exponent
// bias and shift the sum into the float exponent field (bit 23 upward).
vcvttps2dq ymm4,ymm4
vcvttps2dq ymm5,ymm5
vcvttps2dq ymm6,ymm6
vcvttps2dq ymm7,ymm7
vbroadcastss ymm15,ErfOne[rax]
vpaddd ymm4,ymm4,ymm12 # +127
vpaddd ymm5,ymm5,ymm12
vpaddd ymm6,ymm6,ymm12
vpaddd ymm7,ymm7,ymm12
vpslld ymm4,ymm4,23
vpslld ymm5,ymm5,23
vpslld ymm6,ymm6,23
vpslld ymm7,ymm7,23
vmulps ymm8,ymm8,ymm4 # 2^i * exp(vf)
vmulps ymm9,ymm9,ymm5
vmulps ymm10,ymm10,ymm6
vmulps ymm11,ymm11,ymm7
// Big-range result = ErfOne - exp(...).
vsubps ymm8,ymm15,ymm8
vsubps ymm9,ymm15,ymm9
vsubps ymm10,ymm15,ymm10
vsubps ymm11,ymm15,ymm11
// NOTE(review): the OR-merge below presumably relies on the big path
// yielding all-zero bits in lanes that were fed zero - confirm against the
// constant values.
# merge small numbers' result
vorps ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp]
vorps ymm9,ymm9,YMMWORD PTR ErfBuffer1[rsp+32]
vorps ymm10,ymm10,YMMWORD PTR ErfBuffer1[rsp+64]
vorps ymm11,ymm11,YMMWORD PTR ErfBuffer1[rsp+96]
# copy sign
vorps ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp]
vorps ymm1,ymm9,YMMWORD PTR ErfBuffer0[rsp+32]
vorps ymm2,ymm10,YMMWORD PTR ErfBuffer0[rsp+64]
vorps ymm3,ymm11,YMMWORD PTR ErfBuffer0[rsp+96]
vmovups YMMWORD PTR [rsi],ymm0
vmovups YMMWORD PTR [rsi+32],ymm1
vmovups YMMWORD PTR [rsi+64],ymm2
vmovups YMMWORD PTR [rsi+96],ymm3
add rdi,32*4 # advance by 4*8 elements
add rsi,32*4
sub rdx,32
jae .LComputeErf4x8Loop
.LErfProcessRemainingCount:
add rdx,32 # correct for over-subtract above
jz .LErfBatchExp
//
// Remainder loop: one vector of up to 8 elements per iteration, using the
// same computation as the main loop on a single register set.
//
.LErfProcess1x8:
// Build the lane mask in ymm3 by broadcasting the remaining count and
// comparing it (signed greater-than) against the MlasMaskMoveAvx table.
// NOTE(review): the table presumably holds the lane indices 0..7 - confirm
// against its definition.
mov DWORD PTR ErfKernelFrame_CountN[rsp],edx
mov rcx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip]
vbroadcastss ymm3,DWORD PTR ErfKernelFrame_CountN[rsp]
vpcmpgtd ymm3,ymm3,YMMWORD PTR [rcx]
vbroadcastss ymm15,ErfNegZero[rax]
// Masked load: lanes whose mask is clear read as zero and are not accessed.
vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi] # original input vx0
vandps ymm4,ymm0,ymm15 # vsign0
vandnps ymm0,ymm15,ymm0 # abs(vx0) va0
vbroadcastss ymm14,ErfUpperAbsRange[rax]
vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4
vbroadcastss ymm8,ErfSMALL_P0[rax]
vminps ymm0,ymm0,ymm14 # force abs value in range
vbroadcastss ymm15,ErfSMALL_P1[rax]
vmulps ymm4,ymm0,ymm0 # vs0 (square)
vbroadcastss ymm14,ErfSMALL_P2[rax]
vfmadd213ps ymm8,ymm4,ymm15
vbroadcastss ymm13,ErfSMALL_P3[rax]
vfmadd213ps ymm8,ymm4,ymm14
vbroadcastss ymm15,ErfSMALL_P4[rax]
vfmadd213ps ymm8,ymm4,ymm13
vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax]
vfmadd213ps ymm8,ymm4,ymm15
vfmadd213ps ymm8,ymm4,ymm14
vbroadcastss ymm12,ErfSplitBoundary[rax]
vfmadd213ps ymm8,ymm0,ymm0
vcmpgtps ymm4,ymm0,ymm12 # vmask0
vandnps ymm8,ymm4,ymm8
vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8
// This label is not a branch target in this routine.
.BiggerNumbersRemaining:
vbroadcastss ymm15,ErfBIG_P1[rax]
vbroadcastss ymm8,ErfBIG_P0[rax]
vandps ymm0,ymm4,ymm0
vbroadcastss ymm14,ErfBIG_P2[rax]
vfmadd213ps ymm8,ymm0,ymm15
vbroadcastss ymm13,ErfBIG_P3[rax]
vfmadd213ps ymm8,ymm0,ymm14
vbroadcastss ymm15,ErfBIG_P4[rax]
vfmadd213ps ymm8,ymm0,ymm13
vbroadcastss ymm14,ErfBIG_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15
vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax]
vfmadd213ps ymm8,ymm0,ymm14
vbroadcastss ymm15,ErfNegZero[rax]
vfmadd213ps ymm8,ymm0,ymm13
vbroadcastss ymm14,Exp_LowerRange[rax]
vfmadd213ps ymm8,ymm0,ymm0
vbroadcastss ymm4,Exp_Log2Reciprocal[rax]
vxorps ymm8,ymm8,ymm15
vbroadcastss ymm13,Exp_C[rax]
// In this single-vector path only ymm8 carries the exp argument; ymm9-ymm11
// are not used.
# expf(ymm8 -- ymm11)
vmaxps ymm8,ymm8,ymm14
vbroadcastss ymm0,Exp_log2_hi[rax]
vfmadd213ps ymm4,ymm8,ymm13
vbroadcastss ymm15,Exp_log2_lo[rax]
vsubps ymm4,ymm4,ymm13 # vr = round()
vfmadd213ps ymm0,ymm4,ymm8 # vf = vr * log2_hi + ve
vbroadcastss ymm8,Exp_P0[rax]
vfmadd231ps ymm0,ymm4,ymm15 # vf += vr * log_2_lo
vbroadcastss ymm14,Exp_P1[rax]
vbroadcastss ymm13,Exp_P2[rax]
vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p1
vbroadcastss ymm12,Exp_P3[rax]
vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p2
vbroadcastss ymm15,Exp_P4[rax]
vfmadd213ps ymm8,ymm0,ymm12 # *+ exp_p3
vbroadcastss ymm14,Exp_P5[rax]
vfmadd213ps ymm8,ymm0,ymm15 # *+ exp_p4
vbroadcastss ymm13,Exp_P6[rax]
vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p5
vbroadcastss ymm12,Exp_X7F[rax]
vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p6
vcvttps2dq ymm4,ymm4
vbroadcastss ymm15,ErfOne[rax]
vpaddd ymm4,ymm4,ymm12 # +127
vpslld ymm4,ymm4,23
vmulps ymm8,ymm8,ymm4 # 2^i * exp(vf)
vsubps ymm8,ymm15,ymm8
# merge small numbers' result
vorps ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp]
# copy sign
vorps ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp]
// Masked store: only lanes selected by ymm3 are written to the output.
vmaskmovps YMMWORD PTR [rsi],ymm3,ymm0
add rdi,8*4
add rsi,8*4
// Loop while elements remain (signed compare: the count may go negative).
sub rdx,8
jg .LErfProcess1x8
.LErfBatchExp:
// Common exit: clear the upper YMM halves and release the stack frame.
vzeroupper
add rsp,ErfKernelFrame_ReturnAddress
ret
.end

View file

@ -4,6 +4,7 @@
#include "core/providers/cpu/math/element_wise_ops.h"
#include <unsupported/Eigen/SpecialFunctions>
#include "core/util/math.h"
#include "core/mlas/inc/mlas.h"
#include <cmath>
@ -1032,7 +1033,8 @@ Status Erf<float>::Compute(OpKernelContext* context) const {
ORT_ENFORCE(X_ptr != nullptr);
auto& X = *X_ptr;
auto& Y = *context->Output(0, X.Shape());
EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
MlasComputeErf(X.template Data<float>(), Y.template MutableData<float>(), X.Shape().Size());
return Status::OK();
}

View file

@ -1025,6 +1025,35 @@ TEST(MathOpTest, Erf) {
test.Run();
}
TEST(MathOpTest, ErfMoreData) {
  // Feeds a 1-D tensor of 59 values through Erf (opset 9). The values mix
  // magnitudes near zero with magnitudes between 1.9 and 3.7, in both signs.
  std::vector<float> x_values{
      -3.625f, 3.375f, 0.0f, 0.00025f, 0.0005f, -0.00075f, -0.001f, 0.00125f,
      0.0015f, -3.125f, 0.00175f, 2.875f, 2.625f, 2.375f, 2.125f, 6.25e-05f,
      0.0003125f, 0.0005625f, -0.0008125f, 0.0010625f, 0.0013125f, 0.0015625f, 0.0018125f, 3.5625f,
      3.3125f, 3.0625f, 2.8125f, -2.5625f, 2.3125f, 2.0625f, 0.000125f, 0.000375f,
      -0.000625f, -0.000875f, -0.001125f, -0.001375f, -0.001625f, -0.001875f, -3.5f, -3.25f,
      3.0f, 2.75f, -2.5f, -2.25f, -2.0f, -0.0001875f, 0.0004375f, 0.0006875f,
      2.1875f, -1.9375f, 0.0014375f, -0.0016875f, -0.0019375f, 3.4375f, 3.1875f, -2.9375f,
      -2.4375f, -0.0009375f, 0.0011875f
  };

  // Expected erf(x) for each entry of x_values, in the same order.
  std::vector<float> expected_erf{
      -1.0f, 0.999998f, 0.0f, 0.000282095f, 0.00056419f, -0.000846284f, -0.00112838f, 0.00141047f,
      0.00169257f, -0.99999f, 0.00197466f, 0.999952f, 0.999795f, 0.999217f, 0.997346f, 7.05237e-05f,
      0.000352618f, 0.000634713f, -0.000916808f, 0.0011989f, 0.001481f, 0.00176309f, 0.00204518f, 1.0f,
      0.999997f, 0.999985f, 0.99993f, -0.99971f, 0.998926f, 0.996464f, 0.000141047f, 0.000423142f,
      -0.000705237f, -0.000987331f, -0.00126943f, -0.00155152f, -0.00183361f, -0.00211571f, -0.999999f, -0.999996f,
      0.999978f, 0.999899f, -0.999593f, -0.998537f, -0.995322f, -0.000211571f, 0.000493666f, 0.000775761f,
      0.998022f, -0.993857f, 0.00162204f, -0.00190414f, -0.00218623f, 0.999999f, 0.999993f, -0.999967f,
      -0.999433f, -0.00105786f, 0.00133995f
  };

  // Input and output share one shape: a single dimension of x_values.size().
  std::vector<int64_t> shape{static_cast<int64_t>(x_values.size())};

  OpTester tester("Erf", 9);
  tester.AddInput<float>("A", shape, x_values);
  tester.AddOutput<float>("B", shape, expected_erf);
  tester.Run();
}
const int ModOp_ver = 10;
TEST(ModOpTest, Fmod_float_mixed_sign) {