mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
Zhalei/erff (#846)
Implement error function in mlas with avx2 optimization.
This commit is contained in:
parent
7e88ca19ee
commit
468de7c8af
9 changed files with 1484 additions and 1 deletions
|
|
@ -10,6 +10,7 @@ set(mlas_common_srcs
|
|||
${ONNXRUNTIME_ROOT}/core/mlas/lib/activate.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/logistic.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/tanh.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/erf.cpp
|
||||
)
|
||||
|
||||
if(MSVC)
|
||||
|
|
@ -65,6 +66,7 @@ if(MSVC)
|
|||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/cvtfp16a.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/LogisticKernelFma3.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/TanhKernelFma3.asm
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/ErfKernelFma3.asm
|
||||
)
|
||||
|
||||
endif()
|
||||
|
|
@ -157,6 +159,7 @@ else()
|
|||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/LogisticKernelFma3.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/TanhKernelFma3.S
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/ErfKernelFma3.S
|
||||
)
|
||||
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
|
||||
|
||||
|
|
|
|||
|
|
@ -226,6 +226,14 @@ MlasComputeTanh(
|
|||
size_t N
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasComputeErf(
|
||||
const float* Input,
|
||||
float* Output,
|
||||
size_t N
|
||||
);
|
||||
|
||||
//
|
||||
// Half-precision floating-point routines.
|
||||
//
|
||||
|
|
|
|||
569
onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm
Normal file
569
onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm
Normal file
|
|
@ -0,0 +1,569 @@
|
|||
;++
|
||||
;
|
||||
; Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
;
|
||||
; Licensed under the MIT License.
|
||||
;
|
||||
; Module Name:
|
||||
;
|
||||
; ErfKernelFma3.asm
|
||||
;
|
||||
; Abstract:
|
||||
;
|
||||
; This module implements a kernel for computing the error function for a
|
||||
; buffer of elements.
|
||||
;
|
||||
; This implementation uses AVX fused multiply/add instructions.
|
||||
;
|
||||
;--
|
||||
|
||||
.xlist
|
||||
INCLUDE mlasi.inc
|
||||
.list
|
||||
|
||||
EXTERN MlasMaskMoveAvx:NEAR
|
||||
EXTERN MlasErfConstants:NEAR
|
||||
|
||||
;
|
||||
; Structure layout for the erf constants block.
|
||||
;
|
||||
|
||||
ErfConstants STRUCT

;
; Field order and sizes mirror the MlasErfConstants block defined in erf.cpp;
; the kernel addresses that block through these offsets, so the two layouts
; must be kept in sync.
;

;
; Constants for the erf approximation itself.
;

        ErfUpperAbsRange        DWORD   ?       ; clamp applied to abs(x)
        ErfSplitBoundary        DWORD   ?       ; abs(x) above this uses the "big" path
        ErfSMALL_P0             DWORD   ?       ; small-range polynomial coefficients
        ErfSMALL_P1             DWORD   ?
        ErfSMALL_P2             DWORD   ?
        ErfSMALL_P3             DWORD   ?
        ErfSMALL_P4             DWORD   ?
        ErfSMALL_P5_Minus_One   DWORD   ?
        ErfReserved0            DWORD   ?       ; not referenced by the kernel
        ErfBIG_P0               DWORD   ?       ; big-range polynomial coefficients
        ErfBIG_P1               DWORD   ?
        ErfBIG_P2               DWORD   ?
        ErfBIG_P3               DWORD   ?
        ErfBIG_P4               DWORD   ?
        ErfBIG_P5               DWORD   ?
        ErfBIG_P6_Minus_One     DWORD   ?
        ErfNegZero              DWORD   ?       ; sign-bit mask (-0.0f)
        ErfOne                  DWORD   ?

;
; Constants for the exp() evaluation used by the big-range path.
;

        Exp_UpperRange          DWORD   ?       ; not referenced by the kernel
        Exp_LowerRange          DWORD   ?       ; lower clamp of the exp() argument
        Exp_Log2Reciprocal      DWORD   ?
        Exp_log2_hi             DWORD   ?
        Exp_log2_lo             DWORD   ?
        Exp_P0                  DWORD   ?       ; exp() polynomial coefficients
        Exp_P1                  DWORD   ?
        Exp_P2                  DWORD   ?
        Exp_P3                  DWORD   ?
        Exp_P4                  DWORD   ?
        Exp_P5                  DWORD   ?
        Exp_P6                  DWORD   ?
        Exp_C                   DWORD   ?       ; added then subtracted to round to an integer
        Exp_X7F                 DWORD   ?       ; integer 127 added before the shift by 23

ErfConstants ENDS
|
||||
|
||||
;
|
||||
; Stack frame layout for the erf kernel.
|
||||
;
|
||||
|
||||
ErfKernelFrame STRUCT

;
; Scratch storage used inside the kernel loops: each buffer holds four
; 32-byte vectors (8 OWORDs).
;

        ErfBuffer0      OWORD   8 DUP(?)        ; sign bits of the current inputs
        ErfBuffer1      OWORD   8 DUP(?)        ; masked small-range results

;
; Save area for the XMM registers that the prologue preserves.
;

        SavedXmm6       OWORD   ?
        SavedXmm7       OWORD   ?
        SavedXmm8       OWORD   ?
        SavedXmm9       OWORD   ?
        SavedXmm10      OWORD   ?
        SavedXmm11      OWORD   ?
        SavedXmm12      OWORD   ?
        SavedXmm13      OWORD   ?
        SavedXmm14      OWORD   ?
        SavedXmm15      OWORD   ?

;
; NOTE(review): the padding presumably keeps the OWORD save area 16-byte
; aligned once the frame is allocated -- confirm before resizing the frame.
;

        Padding0        QWORD   ?
        Padding1        QWORD   ?
        CountN          QWORD   ?               ; slot used to broadcast the remaining count

;
; Everything above ReturnAddress is allocated by the prologue; the fields
; from ReturnAddress onward describe what the caller already placed on the
; stack.
;

        ReturnAddress   QWORD   ?
        PreviousP1Home  QWORD   ?
        PreviousP2Home  QWORD   ?
        PreviousP3Home  QWORD   ?
        PreviousP4Home  QWORD   ?

ErfKernelFrame ENDS
|
||||
|
||||
;++
|
||||
;
|
||||
; Routine Description:
|
||||
;
|
||||
; This routine implements a vectorized kernel for the error function.
|
||||
;
|
||||
; Arguments:
|
||||
;
|
||||
; Input (rcx) - Supplies the input buffer.
|
||||
;
|
||||
; Output (rdx) - Supplies the output buffer.
|
||||
;
|
||||
; N (r8) - Supplies the number of elements to process.
|
||||
;
|
||||
; Return Value:
|
||||
;
|
||||
; None.
|
||||
;
|
||||
;--
|
||||
|
||||
NESTED_ENTRY MlasErfKernelFma3, _TEXT

;
; Register usage:
;
;   rcx - input pointer, rdx - output pointer, r8 - remaining element count,
;   rax - address of the MlasErfConstants block.
;
; Every element is evaluated through both approximations and the results are
; merged with masks:
;
;   abs(x) <= ErfSplitBoundary: a + a * P_small(a * a)
;   abs(x) >  ErfSplitBoundary: 1 - exp(-(a + a * P_big(a)))
;
; where a = min(abs(x), ErfUpperAbsRange). The sign of x is OR'ed back into
; the merged result.
;
; The range clamp is written as "vminps dst,limit,abs" so that a NaN input
; (the second source operand) is returned unchanged and propagates to the
; output. With the operands the other way around, vminps would replace NaN
; by the limit and the kernel would return +/-1.0 for NaN.
;

        alloc_stack (ErfKernelFrame.ReturnAddress)

        save_xmm128_avx xmm6,ErfKernelFrame.SavedXmm6
        save_xmm128_avx xmm7,ErfKernelFrame.SavedXmm7
        save_xmm128_avx xmm8,ErfKernelFrame.SavedXmm8
        save_xmm128_avx xmm9,ErfKernelFrame.SavedXmm9
        save_xmm128_avx xmm10,ErfKernelFrame.SavedXmm10
        save_xmm128_avx xmm11,ErfKernelFrame.SavedXmm11
        save_xmm128_avx xmm12,ErfKernelFrame.SavedXmm12
        save_xmm128_avx xmm13,ErfKernelFrame.SavedXmm13
        save_xmm128_avx xmm14,ErfKernelFrame.SavedXmm14
        save_xmm128_avx xmm15,ErfKernelFrame.SavedXmm15

        END_PROLOGUE

        lea     rax,MlasErfConstants
        sub     r8,8*4
        jb      LErfProcessRemainingCount

;
; Main loop: process 4 vectors of 8 elements per iteration.
;

LComputeErf4x8Loop:
        vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
        vmovups ymm0,YMMWORD PTR [rcx]          ; original input vx0
        vmovups ymm1,YMMWORD PTR [rcx+32]       ; original input vx1
        vmovups ymm2,YMMWORD PTR [rcx+64]       ; original input vx2
        vmovups ymm3,YMMWORD PTR [rcx+96]       ; original input vx3

        vandps  ymm4,ymm0,ymm15                 ; vsign0
        vandps  ymm5,ymm1,ymm15                 ; vsign1
        vandps  ymm6,ymm2,ymm15                 ; vsign2
        vandps  ymm7,ymm3,ymm15                 ; vsign3
        vandnps ymm0,ymm15,ymm0                 ; abs(vx0) va0
        vandnps ymm1,ymm15,ymm1                 ; abs(vx1) va1
        vandnps ymm2,ymm15,ymm2                 ; abs(vx2) va2
        vandnps ymm3,ymm15,ymm3                 ; abs(vx3) va3

        vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax]
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+32],ymm5
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+64],ymm6
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+96],ymm7

        vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax]
        vminps  ymm0,ymm14,ymm0                 ; force abs value in range, keep NaN
        vminps  ymm1,ymm14,ymm1
        vminps  ymm2,ymm14,ymm2
        vminps  ymm3,ymm14,ymm3
        vmovaps ymm9,ymm8
        vmovaps ymm10,ymm8
        vmovaps ymm11,ymm8

;
; Small-range polynomial, evaluated in va * va.
;

        vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax]
        vmulps  ymm4,ymm0,ymm0                  ; vs0 (square)
        vmulps  ymm5,ymm1,ymm1                  ; vs1
        vmulps  ymm6,ymm2,ymm2                  ; vs2
        vmulps  ymm7,ymm3,ymm3                  ; vs3

        vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax]
        vfmadd213ps ymm8,ymm4,ymm15
        vfmadd213ps ymm9,ymm5,ymm15
        vfmadd213ps ymm10,ymm6,ymm15
        vfmadd213ps ymm11,ymm7,ymm15

        vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax]
        vfmadd213ps ymm8,ymm4,ymm14
        vfmadd213ps ymm9,ymm5,ymm14
        vfmadd213ps ymm10,ymm6,ymm14
        vfmadd213ps ymm11,ymm7,ymm14

        vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax]
        vfmadd213ps ymm8,ymm4,ymm13
        vfmadd213ps ymm9,ymm5,ymm13
        vfmadd213ps ymm10,ymm6,ymm13
        vfmadd213ps ymm11,ymm7,ymm13

        vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax]
        vfmadd213ps ymm8,ymm4,ymm15
        vfmadd213ps ymm9,ymm5,ymm15
        vfmadd213ps ymm10,ymm6,ymm15
        vfmadd213ps ymm11,ymm7,ymm15

        vfmadd213ps ymm8,ymm4,ymm14
        vfmadd213ps ymm9,ymm5,ymm14
        vfmadd213ps ymm10,ymm6,ymm14
        vfmadd213ps ymm11,ymm7,ymm14

        vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax]
        vfmadd213ps ymm8,ymm0,ymm0              ; poly * va + va
        vfmadd213ps ymm9,ymm1,ymm1
        vfmadd213ps ymm10,ymm2,ymm2
        vfmadd213ps ymm11,ymm3,ymm3

        vcmpgtps ymm4,ymm0,ymm12                ; vmask0
        vcmpgtps ymm5,ymm1,ymm12                ; vmask1
        vcmpgtps ymm6,ymm2,ymm12                ; vmask2
        vcmpgtps ymm7,ymm3,ymm12                ; vmask3

        vandnps ymm8,ymm4,ymm8                  ; keep small results where mask is clear
        vandnps ymm9,ymm5,ymm9
        vandnps ymm10,ymm6,ymm10
        vandnps ymm11,ymm7,ymm11

        vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax]
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32],ymm9
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64],ymm10
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96],ymm11

;
; Big-range polynomial, evaluated in va with the small lanes zeroed.
;

LBiggerNumbers:
        vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax]
        vandps  ymm0,ymm4,ymm0
        vandps  ymm1,ymm5,ymm1
        vandps  ymm2,ymm6,ymm2
        vandps  ymm3,ymm7,ymm3
        vmovaps ymm9,ymm8
        vmovaps ymm10,ymm8
        vmovaps ymm11,ymm8

        vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm15
        vfmadd213ps ymm9,ymm1,ymm15
        vfmadd213ps ymm10,ymm2,ymm15
        vfmadd213ps ymm11,ymm3,ymm15

        vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm14
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm13
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15
        vfmadd213ps ymm9,ymm1,ymm15
        vfmadd213ps ymm10,ymm2,ymm15
        vfmadd213ps ymm11,ymm3,ymm15

        vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax]
        vfmadd213ps ymm8,ymm0,ymm14
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
        vfmadd213ps ymm8,ymm0,ymm13
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax]
        vfmadd213ps ymm8,ymm0,ymm0              ; poly * va + va
        vfmadd213ps ymm9,ymm1,ymm1
        vfmadd213ps ymm10,ymm2,ymm2
        vfmadd213ps ymm11,ymm3,ymm3

        vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax]
        vxorps  ymm8,ymm8,ymm15                 ; negate: ve = -poly
        vxorps  ymm9,ymm9,ymm15
        vxorps  ymm10,ymm10,ymm15
        vxorps  ymm11,ymm11,ymm15

        vbroadcastss ymm13,ErfConstants.Exp_C[rax]
        vmovaps ymm5,ymm4
        vmovaps ymm6,ymm4
        vmovaps ymm7,ymm4

;
; expf(ymm8 -- ymm11)
;

        vmaxps  ymm8,ymm8,ymm14
        vmaxps  ymm9,ymm9,ymm14
        vmaxps  ymm10,ymm10,ymm14
        vmaxps  ymm11,ymm11,ymm14

        vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax]
        vfmadd213ps ymm4,ymm8,ymm13
        vfmadd213ps ymm5,ymm9,ymm13
        vfmadd213ps ymm6,ymm10,ymm13
        vfmadd213ps ymm7,ymm11,ymm13

        vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax]
        vmovaps ymm1,ymm0
        vmovaps ymm2,ymm0
        vmovaps ymm3,ymm0

        vsubps  ymm4,ymm4,ymm13                 ; vr = round()
        vsubps  ymm5,ymm5,ymm13
        vsubps  ymm6,ymm6,ymm13
        vsubps  ymm7,ymm7,ymm13

        vfmadd213ps ymm0,ymm4,ymm8              ; vf = vr * log2_hi + ve
        vfmadd213ps ymm1,ymm5,ymm9
        vfmadd213ps ymm2,ymm6,ymm10
        vfmadd213ps ymm3,ymm7,ymm11

        vbroadcastss ymm8,ErfConstants.Exp_P0[rax]
        vfmadd231ps ymm0,ymm4,ymm15             ; vf += vr * log_2_lo
        vfmadd231ps ymm1,ymm5,ymm15
        vfmadd231ps ymm2,ymm6,ymm15
        vfmadd231ps ymm3,ymm7,ymm15
        vmovaps ymm9,ymm8
        vmovaps ymm10,ymm8
        vmovaps ymm11,ymm8

        vbroadcastss ymm14,ErfConstants.Exp_P1[rax]
        vbroadcastss ymm13,ErfConstants.Exp_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm14             ; *+ exp_p1
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm12,ErfConstants.Exp_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm13             ; *+ exp_p2
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vbroadcastss ymm15,ErfConstants.Exp_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm12             ; *+ exp_p3
        vfmadd213ps ymm9,ymm1,ymm12
        vfmadd213ps ymm10,ymm2,ymm12
        vfmadd213ps ymm11,ymm3,ymm12

        vbroadcastss ymm14,ErfConstants.Exp_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15             ; *+ exp_p4
        vfmadd213ps ymm9,ymm1,ymm15
        vfmadd213ps ymm10,ymm2,ymm15
        vfmadd213ps ymm11,ymm3,ymm15

        vbroadcastss ymm13,ErfConstants.Exp_P6[rax]
        vfmadd213ps ymm8,ymm0,ymm14             ; *+ exp_p5
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm12,ErfConstants.Exp_X7F[rax]
        vfmadd213ps ymm8,ymm0,ymm13             ; *+ exp_p6
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vcvttps2dq ymm4,ymm4
        vcvttps2dq ymm5,ymm5
        vcvttps2dq ymm6,ymm6
        vcvttps2dq ymm7,ymm7

        vbroadcastss ymm15,ErfConstants.ErfOne[rax]
        vpaddd  ymm4,ymm4,ymm12                 ; +127
        vpaddd  ymm5,ymm5,ymm12
        vpaddd  ymm6,ymm6,ymm12
        vpaddd  ymm7,ymm7,ymm12

        vpslld  ymm4,ymm4,23                    ; move into the exponent field
        vpslld  ymm5,ymm5,23
        vpslld  ymm6,ymm6,23
        vpslld  ymm7,ymm7,23

        vmulps  ymm8,ymm8,ymm4                  ; 2^i * exp(vf)
        vmulps  ymm9,ymm9,ymm5
        vmulps  ymm10,ymm10,ymm6
        vmulps  ymm11,ymm11,ymm7

        vsubps  ymm8,ymm15,ymm8                 ; 1.0 - exp()
        vsubps  ymm9,ymm15,ymm9
        vsubps  ymm10,ymm15,ymm10
        vsubps  ymm11,ymm15,ymm11

;
; Merge the small-range results.
;

        vorps   ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp]
        vorps   ymm9,ymm9,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32]
        vorps   ymm10,ymm10,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64]
        vorps   ymm11,ymm11,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96]

;
; Copy the sign of the input.
;

        vorps   ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp]
        vorps   ymm1,ymm9,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+32]
        vorps   ymm2,ymm10,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+64]
        vorps   ymm3,ymm11,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+96]

        vmovups YMMWORD PTR [rdx],ymm0
        vmovups YMMWORD PTR [rdx+32],ymm1
        vmovups YMMWORD PTR [rdx+64],ymm2
        vmovups YMMWORD PTR [rdx+96],ymm3

        add     rcx,32*4                        ; advance by 4*8 elements
        add     rdx,32*4
        sub     r8,32
        jae     LComputeErf4x8Loop

LErfProcessRemainingCount:
        add     r8,32                           ; correct for over-subtract above
        jz      LErfBatchExp

;
; Tail loop: process up to 8 elements per iteration using a lane mask built
; from the remaining count.
;

LErfProcess1x8:
        mov     DWORD PTR ErfKernelFrame.CountN[rsp],r8d
        vbroadcastss ymm3,DWORD PTR ErfKernelFrame.CountN[rsp]

        vpcmpgtd ymm3,ymm3,YMMWORD PTR [MlasMaskMoveAvx]
        vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
        vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx]  ; original input vx0

        vandps  ymm4,ymm0,ymm15                 ; vsign0
        vandnps ymm0,ymm15,ymm0                 ; abs(vx0) va0

        vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax]
        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4

        vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax]
        vminps  ymm0,ymm14,ymm0                 ; force abs value in range, keep NaN

        vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax]
        vmulps  ymm4,ymm0,ymm0                  ; vs0 (square)

        vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax]
        vfmadd213ps ymm8,ymm4,ymm15

        vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax]
        vfmadd213ps ymm8,ymm4,ymm14

        vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax]
        vfmadd213ps ymm8,ymm4,ymm13

        vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax]
        vfmadd213ps ymm8,ymm4,ymm15

        vfmadd213ps ymm8,ymm4,ymm14

        vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax]
        vfmadd213ps ymm8,ymm0,ymm0              ; poly * va + va

        vcmpgtps ymm4,ymm0,ymm12                ; vmask0

        vandnps ymm8,ymm4,ymm8                  ; keep small results where mask is clear

        vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8

LBiggerNumbersRemaining:
        vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax]
        vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax]
        vandps  ymm0,ymm4,ymm0

        vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm15

        vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm14

        vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm13

        vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15

        vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax]
        vfmadd213ps ymm8,ymm0,ymm14

        vbroadcastss ymm15,ErfConstants.ErfNegZero[rax]
        vfmadd213ps ymm8,ymm0,ymm13

        vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax]
        vfmadd213ps ymm8,ymm0,ymm0              ; poly * va + va

        vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax]
        vxorps  ymm8,ymm8,ymm15                 ; negate: ve = -poly

        vbroadcastss ymm13,ErfConstants.Exp_C[rax]

;
; expf(ymm8)
;

        vmaxps  ymm8,ymm8,ymm14

        vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax]
        vfmadd213ps ymm4,ymm8,ymm13

        vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax]

        vsubps  ymm4,ymm4,ymm13                 ; vr = round()

        vfmadd213ps ymm0,ymm4,ymm8              ; vf = vr * log2_hi + ve

        vbroadcastss ymm8,ErfConstants.Exp_P0[rax]

        vfmadd231ps ymm0,ymm4,ymm15             ; vf += vr * log_2_lo

        vbroadcastss ymm14,ErfConstants.Exp_P1[rax]

        vbroadcastss ymm13,ErfConstants.Exp_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm14             ; *+ exp_p1

        vbroadcastss ymm12,ErfConstants.Exp_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm13             ; *+ exp_p2

        vbroadcastss ymm15,ErfConstants.Exp_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm12             ; *+ exp_p3

        vbroadcastss ymm14,ErfConstants.Exp_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15             ; *+ exp_p4

        vbroadcastss ymm13,ErfConstants.Exp_P6[rax]
        vfmadd213ps ymm8,ymm0,ymm14             ; *+ exp_p5

        vbroadcastss ymm12,ErfConstants.Exp_X7F[rax]
        vfmadd213ps ymm8,ymm0,ymm13             ; *+ exp_p6

        vcvttps2dq ymm4,ymm4

        vbroadcastss ymm15,ErfConstants.ErfOne[rax]
        vpaddd  ymm4,ymm4,ymm12                 ; +127

        vpslld  ymm4,ymm4,23                    ; move into the exponent field

        vmulps  ymm8,ymm8,ymm4                  ; 2^i * exp(vf)

        vsubps  ymm8,ymm15,ymm8                 ; 1.0 - exp()

;
; Merge the small-range result.
;

        vorps   ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp]

;
; Copy the sign of the input.
;

        vorps   ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp]

        vmaskmovps YMMWORD PTR [rdx],ymm3,ymm0

        add     rcx,8*4
        add     rdx,8*4
        sub     r8,8
        jg      LErfProcess1x8

;
; Restore the saved XMM registers and release the stack frame.
;

LErfBatchExp:
        vzeroupper
        vmovaps xmm6,ErfKernelFrame.SavedXmm6[rsp]
        vmovaps xmm7,ErfKernelFrame.SavedXmm7[rsp]
        vmovaps xmm8,ErfKernelFrame.SavedXmm8[rsp]
        vmovaps xmm9,ErfKernelFrame.SavedXmm9[rsp]
        vmovaps xmm10,ErfKernelFrame.SavedXmm10[rsp]
        vmovaps xmm11,ErfKernelFrame.SavedXmm11[rsp]
        vmovaps xmm12,ErfKernelFrame.SavedXmm12[rsp]
        vmovaps xmm13,ErfKernelFrame.SavedXmm13[rsp]
        vmovaps xmm14,ErfKernelFrame.SavedXmm14[rsp]
        vmovaps xmm15,ErfKernelFrame.SavedXmm15[rsp]
        add     rsp,(ErfKernelFrame.ReturnAddress)

        BEGIN_EPILOGUE

        ret

NESTED_END MlasErfKernelFma3, _TEXT
|
||||
|
||||
END
|
||||
271
onnxruntime/core/mlas/lib/erf.cpp
Normal file
271
onnxruntime/core/mlas/lib/erf.cpp
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
erf.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements routines to compute the error function.
|
||||
|
||||
This implementation uses the same polynomial coefficients and algorithm as
|
||||
found in: https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff
|
||||
Our usage requires building platform specific versions of
|
||||
the algorithm to target different instruction sets. The implementation below
|
||||
targets the base instruction set (typically SSE2) while assembly
|
||||
implementations target newer instruction sets (such as FMA3).
|
||||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
//
|
||||
// Bundles the constants for use by kernels written in assembly.
|
||||
//
|
||||
|
||||
extern "C" const struct {

    //
    // Constants of the erf approximation. The assembly kernels address this
    // block through a matching structure definition, so the order and size of
    // the fields must not change.
    //

    float ErfUpperAbsRange;         // abs(x) is clamped to this value
    float ErfSplitBoundary;         // abs(x) above this uses the "big" polynomial
    float ErfSMALL_P0;              // small-range polynomial (in x * x)
    float ErfSMALL_P1;
    float ErfSMALL_P2;
    float ErfSMALL_P3;
    float ErfSMALL_P4;
    float ErfSMALL_P5_Minus_One;
    float ErfReserved0;             // not read by the kernel in this file
    float ErfBIG_P0;                // big-range polynomial (in abs(x))
    float ErfBIG_P1;
    float ErfBIG_P2;
    float ErfBIG_P3;
    float ErfBIG_P4;
    float ErfBIG_P5;
    float ErfBIG_P6_Minus_One;
    float ErfNegZero;               // sign-bit mask
    float ErfOne;

    //
    // Constants of the exp() evaluation used by the big-range path.
    //

    float Exp_UpperRange;           // not read by the kernel in this file
    float Exp_LowerRange;           // lower clamp of the exp() argument
    float Exp_Log2Reciprocal;
    float Exp_log2_hi;
    float Exp_log2_lo;
    float Exp_P0;                   // exp() polynomial
    float Exp_P1;
    float Exp_P2;
    float Exp_P3;
    float Exp_P4;
    float Exp_P5;
    float Exp_P6;
    float Exp_C;                    // added then subtracted to round to an integer
    int32_t Exp_X7F;                // added to the integer exponent before the shift by 23

} MlasErfConstants = {
    3.725f,                         // ErfUpperAbsRange
    0.921875f,                      // ErfSplitBoundary
    -5.99104969e-4f,                // ErfSMALL_P0
    4.99339588e-3f,                 // ErfSMALL_P1
    -2.67667342e-2f,                // ErfSMALL_P2
    1.12818025e-1f,                 // ErfSMALL_P3
    -3.76124859e-1f,                // ErfSMALL_P4
    1.28379151e-1f,                 // ErfSMALL_P5_Minus_One
    0.0f,                           // ErfReserved0
    1.72948930e-5f,                 // ErfBIG_P0
    -3.83208680e-4f,                // ErfBIG_P1
    3.88393435e-3f,                 // ErfBIG_P2
    -2.42545605e-2f,                // ErfBIG_P3
    1.06777847e-1f,                 // ErfBIG_P4
    6.34846687e-1f,                 // ErfBIG_P5
    1.28717512e-1f,                 // ErfBIG_P6_Minus_One
    -0.0f,                          // ErfNegZero
    1.0f,                           // ErfOne

    // Independent parameters to calculate Exp for Erff()
    88.3762626647950f,              // Exp_UpperRange
    -88.3762626647949f,             // Exp_LowerRange
    1.44269504088896341f,           // Exp_Log2Reciprocal
    -6.93145752e-1f,                // Exp_log2_hi
    -1.42860677e-6f,                // Exp_log2_lo
    1.38319808e-3f,                 // Exp_P0
    8.37550033e-3f,                 // Exp_P1
    4.16689515e-2f,                 // Exp_P2
    1.66664466e-1f,                 // Exp_P3
    4.99999851e-1f,                 // Exp_P4
    1.00000000e+0f,                 // Exp_P5
    1.00000000e+0f,                 // Exp_P6
    1.25829120e+7f,                 // Exp_C
    127,                            // Exp_X7F
};
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasErfKernel(
|
||||
const float* Input,
|
||||
float* Output,
|
||||
size_t N
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the generic kernel for the error function.
|
||||
|
||||
Arguments:
|
||||
|
||||
Input - Supplies the input buffer.
|
||||
|
||||
Output - Supplies the output buffer.
|
||||
|
||||
N - Supplies the number of elements to process.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
while (N >= 4) {
|
||||
MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input);
|
||||
MLAS_FLOAT32X4 NegZero = MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero);
|
||||
MLAS_FLOAT32X4 SignMask = MlasAndFloat32x4(Value, NegZero);
|
||||
MLAS_FLOAT32X4 AbsValue = MlasAndNotFloat32x4(NegZero, Value);
|
||||
AbsValue = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.ErfUpperAbsRange), AbsValue);
|
||||
MLAS_FLOAT32X4 SquareValue = MlasMultiplyFloat32x4(AbsValue, AbsValue);
|
||||
|
||||
MLAS_FLOAT32X4 r_small = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P0);
|
||||
r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P1));
|
||||
r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P2));
|
||||
r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P3));
|
||||
r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P4));
|
||||
r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P5_Minus_One));
|
||||
r_small = MlasMultiplyAddFloat32x4(r_small, AbsValue, AbsValue);
|
||||
MLAS_FLOAT32X4 split_mask = MlasGreaterThanFloat32x4(AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSplitBoundary));
|
||||
r_small = MlasAndNotFloat32x4(split_mask, r_small);
|
||||
|
||||
AbsValue = MlasAndFloat32x4(split_mask, AbsValue); // clear smaller value into zero for bigger number calculation
|
||||
MLAS_FLOAT32X4 r_big = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P0);
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P1));
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P2));
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P3));
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P4));
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P5));
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P6_Minus_One));
|
||||
r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, AbsValue);
|
||||
|
||||
// 1.0 - exp(-r_big), no need to do min()
|
||||
r_big = MlasXorFloat32x4(r_big, MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero)); // -r_big
|
||||
r_big = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.Exp_LowerRange), r_big);
|
||||
MLAS_FLOAT32X4 exp_c = MlasBroadcastFloat32x4(MlasErfConstants.Exp_C);
|
||||
MLAS_FLOAT32X4 r = MlasMultiplyAddFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.Exp_Log2Reciprocal), r_big, exp_c);
|
||||
r = MlasSubtractFloat32x4(r, exp_c);
|
||||
|
||||
MLAS_FLOAT32X4 fx = MlasMultiplyAddFloat32x4(r, MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_hi), r_big);
|
||||
fx = MlasMultiplyAddFloat32x4(r, MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_lo), fx);
|
||||
// y = exp(fx)
|
||||
MLAS_FLOAT32X4 y = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P0);
|
||||
y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P1));
|
||||
y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P2));
|
||||
y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P3));
|
||||
y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P4));
|
||||
y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P5));
|
||||
y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P6));
|
||||
// 1.0 - exp(fx) * 2^INT(r)
|
||||
y = MlasMultiplyFloat32x4(y, MlasPowerOf2Float32x4(r));
|
||||
y = MlasSubtractFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.ErfOne), y);
|
||||
|
||||
// merge two splits results
|
||||
y = MlasOrFloat32x4(r_small, y);
|
||||
y = MlasOrFloat32x4(y, SignMask);
|
||||
|
||||
MlasStoreFloat32x4(Output, y);
|
||||
|
||||
Input += 4;
|
||||
Output += 4;
|
||||
N -= 4;
|
||||
}
|
||||
|
||||
while (N > 0) {
|
||||
float Value = *Input++;
|
||||
float AbsValue = fabsf(Value);
|
||||
|
||||
float r;
|
||||
if (AbsValue > MlasErfConstants.ErfSplitBoundary) {
|
||||
AbsValue = (std::min)(MlasErfConstants.ErfUpperAbsRange, AbsValue);
|
||||
float r_big = MlasErfConstants.ErfBIG_P0;
|
||||
r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P1;
|
||||
r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P2;
|
||||
r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P3;
|
||||
r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P4;
|
||||
r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P5;
|
||||
r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P6_Minus_One;
|
||||
r_big = r_big * AbsValue + AbsValue;
|
||||
|
||||
r_big = (std::max)(-r_big, MlasErfConstants.Exp_LowerRange);
|
||||
r = MlasErfConstants.Exp_Log2Reciprocal * r_big + MlasErfConstants.Exp_C;
|
||||
r -= MlasErfConstants.Exp_C;
|
||||
float fx = r * MlasErfConstants.Exp_log2_hi + r_big;
|
||||
fx = r * MlasErfConstants.Exp_log2_lo + fx;
|
||||
|
||||
float y = MlasErfConstants.Exp_P0;
|
||||
y = y * fx + MlasErfConstants.Exp_P1;
|
||||
y = y * fx + MlasErfConstants.Exp_P2;
|
||||
y = y * fx + MlasErfConstants.Exp_P3;
|
||||
y = y * fx + MlasErfConstants.Exp_P4;
|
||||
y = y * fx + MlasErfConstants.Exp_P5;
|
||||
y = y * fx + MlasErfConstants.Exp_P6;
|
||||
|
||||
r = 1.0f - ldexpf(y, (int)r);
|
||||
r = (Value <= -0.0f) ? -r : r;
|
||||
}
|
||||
else {
|
||||
float SquareValue = AbsValue * AbsValue;
|
||||
r = MlasErfConstants.ErfSMALL_P0;
|
||||
r = r * SquareValue + MlasErfConstants.ErfSMALL_P1;
|
||||
r = r * SquareValue + MlasErfConstants.ErfSMALL_P2;
|
||||
r = r * SquareValue + MlasErfConstants.ErfSMALL_P3;
|
||||
r = r * SquareValue + MlasErfConstants.ErfSMALL_P4;
|
||||
r = r * SquareValue + MlasErfConstants.ErfSMALL_P5_Minus_One;
|
||||
r = r * Value + Value;
|
||||
}
|
||||
|
||||
*Output++ = r;
|
||||
N -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasComputeErf(
|
||||
const float* Input,
|
||||
float* Output,
|
||||
size_t N
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine computes the error function.
|
||||
|
||||
Arguments:
|
||||
|
||||
Input - Supplies the input buffer.
|
||||
|
||||
Output - Supplies the output buffer.
|
||||
|
||||
N - Supplies the number of elements to process.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
MlasPlatform.ErfKernelRoutine(Input, Output, N);
|
||||
#else
|
||||
MlasErfKernel(Input, Output, N);
|
||||
#endif
|
||||
}
|
||||
|
|
@ -174,6 +174,16 @@ void
|
|||
|
||||
typedef MLAS_TANH_KERNEL_ROUTINE* PMLAS_TANH_KERNEL_ROUTINE;
|
||||
|
||||
typedef
|
||||
void
|
||||
(MLASCALL MLAS_ERF_KERNEL_ROUTINE)(
|
||||
const float* Input,
|
||||
float* Output,
|
||||
size_t N
|
||||
);
|
||||
|
||||
typedef MLAS_ERF_KERNEL_ROUTINE* PMLAS_ERF_KERNEL_ROUTINE;
|
||||
|
||||
extern "C" {
|
||||
|
||||
MLAS_SGEMM_KERNEL_ROUTINE MlasSgemmKernelZero;
|
||||
|
|
@ -203,9 +213,11 @@ extern "C" {
|
|||
|
||||
MLAS_TANH_KERNEL_ROUTINE MlasLogisticKernel;
|
||||
MLAS_TANH_KERNEL_ROUTINE MlasTanhKernel;
|
||||
MLAS_ERF_KERNEL_ROUTINE MlasErfKernel;
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
MLAS_TANH_KERNEL_ROUTINE MlasLogisticKernelFma3;
|
||||
MLAS_TANH_KERNEL_ROUTINE MlasTanhKernelFma3;
|
||||
MLAS_ERF_KERNEL_ROUTINE MlasErfKernelFma3;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
|
@ -269,6 +281,7 @@ struct MLAS_PLATFORM {
|
|||
PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE TransposePackB16x4Routine;
|
||||
PMLAS_LOGISTIC_KERNEL_ROUTINE LogisticKernelRoutine;
|
||||
PMLAS_TANH_KERNEL_ROUTINE TanhKernelRoutine;
|
||||
PMLAS_ERF_KERNEL_ROUTINE ErfKernelRoutine;
|
||||
#endif
|
||||
|
||||
#if defined(MLAS_USE_WIN32_THREADPOOL)
|
||||
|
|
@ -574,6 +587,75 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
|||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vreinterpretq_f32_u32(vcgtq_f32(Vector1, Vector2));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_cmpgt_ps(Vector1, Vector2);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_and_ps(Vector1, Vector2);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_or_ps(Vector1, Vector2);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vreinterpretq_f32_u32(vandq_u32(vmvnq_u32(vreinterpretq_u32_f32(VectorNot)), vreinterpretq_u32_f32(Vector)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_andnot_ps(VectorNot, Vector);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(Vector1), vreinterpretq_u32_f32(Vector2)));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
return _mm_xor_ps(Vector1, Vector2);
|
||||
#endif
|
||||
}
|
||||
|
||||
// calc 2^int(N)
|
||||
inline
|
||||
MLAS_FLOAT32X4
|
||||
MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector)
|
||||
{
|
||||
#if defined(MLAS_NEON_INTRINSICS)
|
||||
int32x4_t emm0 = vaddq_s32(vcvtq_s32_f32(Vector), vdupq_n_s32(0x7f));
|
||||
return vreinterpretq_f32_s32(vshlq_n_s32(emm0, 23));
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
__m128i emm0 = _mm_add_epi32(_mm_cvttps_epi32(Vector), _mm_set1_epi32(0x7f));
|
||||
return _mm_castsi128_ps(_mm_slli_epi32(emm0, 23));
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// Reads a platform specific time stamp counter.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ Return Value:
|
|||
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
|
||||
this->LogisticKernelRoutine = MlasLogisticKernel;
|
||||
this->TanhKernelRoutine = MlasTanhKernel;
|
||||
this->ErfKernelRoutine = MlasErfKernel;
|
||||
#endif
|
||||
|
||||
//
|
||||
|
|
@ -144,6 +145,7 @@ Return Value:
|
|||
|
||||
this->LogisticKernelRoutine = MlasLogisticKernelFma3;
|
||||
this->TanhKernelRoutine = MlasTanhKernelFma3;
|
||||
this->ErfKernelRoutine = MlasErfKernelFma3;
|
||||
|
||||
} else {
|
||||
|
||||
|
|
|
|||
517
onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S
Normal file
517
onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S
Normal file
|
|
@ -0,0 +1,517 @@
|
|||
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

    ErfKernelFma3.S

Abstract:

    This module implements a kernel for computing the error function for a
    buffer of elements.

    This implementation uses AVX fused multiply/add instructions.

--*/

#include "asmmacro.h"

        .intel_syntax noprefix

        .text

//
// Structure layout for the erf constants block.
//
// These are byte offsets from the address of MlasErfConstants; each entry is
// read with a 32-bit broadcast load.
//

        .equ    ErfUpperAbsRange, 0
        .equ    ErfSplitBoundary, 4
        .equ    ErfSMALL_P0, 8
        .equ    ErfSMALL_P1, 12
        .equ    ErfSMALL_P2, 16
        .equ    ErfSMALL_P3, 20
        .equ    ErfSMALL_P4, 24
        .equ    ErfSMALL_P5_Minus_One, 28
        .equ    ErfReserve0, 32
        .equ    ErfBIG_P0, 36
        .equ    ErfBIG_P1, 40
        .equ    ErfBIG_P2, 44
        .equ    ErfBIG_P3, 48
        .equ    ErfBIG_P4, 52
        .equ    ErfBIG_P5, 56
        .equ    ErfBIG_P6_Minus_One, 60
        .equ    ErfNegZero, 64
        .equ    ErfOne, 68

        .equ    ExpConstOffset, 72
        .equ    Exp_UpperRange, 0 + ExpConstOffset
        .equ    Exp_LowerRange, 4 + ExpConstOffset
        .equ    Exp_Log2Reciprocal, 8 + ExpConstOffset
        .equ    Exp_log2_hi, 12 + ExpConstOffset
        .equ    Exp_log2_lo, 16 + ExpConstOffset
        .equ    Exp_P0, 20 + ExpConstOffset
        .equ    Exp_P1, 24 + ExpConstOffset
        .equ    Exp_P2, 28 + ExpConstOffset
        .equ    Exp_P3, 32 + ExpConstOffset
        .equ    Exp_P4, 36 + ExpConstOffset
        .equ    Exp_P5, 40 + ExpConstOffset
        .equ    Exp_P6, 44 + ExpConstOffset
        .equ    Exp_C, 48 + ExpConstOffset
        .equ    Exp_X7F, 52 + ExpConstOffset

//
// Stack frame layout for the erf kernel.
//
// ErfBuffer0 holds the saved sign bits and ErfBuffer1 holds the masked
// small-range results; each buffer is four YMM registers (128 bytes) wide.
//

        .equ    ErfBuffer0, 0
        .equ    ErfBuffer1, 128
        .equ    ErfKernelFrame_CountN, 256
        .equ    ErfKernelFrame_ReturnAddress, 256+8

/*++

Routine Description:

    This routine implements a vectorized kernel for the error function.

    Each element is processed as follows:

      - The sign bit is split off and the absolute value is clamped to
        ErfUpperAbsRange.
      - A polynomial in the square of the value (ErfSMALL_P*) produces the
        result for values that are not above ErfSplitBoundary.
      - A polynomial in the value itself (ErfBIG_P*) is negated, passed
        through an inlined exponential and subtracted from one to produce the
        result for values above ErfSplitBoundary.
      - The two results are merged using the comparison mask and the saved
        sign bit is ORed back in.

Arguments:

    Input (rdi) - Supplies the input buffer.

    Output (rsi) - Supplies the output buffer.

    N (rdx) - Supplies the number of elements to process.

Return Value:

    None.

--*/

        .globl  C_UNDERSCORE(MlasErfKernelFma3)
C_UNDERSCORE(MlasErfKernelFma3):
        sub     rsp,ErfKernelFrame_ReturnAddress
        mov     rax,C_UNDERSCORE(MlasErfConstants)@GOTPCREL[rip]

        sub     rdx,8*4
        jb      .LErfProcessRemainingCount

//
// Main loop: process 32 elements (four YMM vectors) per iteration.
//

.LComputeErf4x8Loop:
        vbroadcastss ymm15,ErfNegZero[rax]
        vmovups ymm0,YMMWORD PTR [rdi]              # original input vx0
        vmovups ymm1,YMMWORD PTR [rdi+32]           # original input vx1
        vmovups ymm2,YMMWORD PTR [rdi+64]           # original input vx2
        vmovups ymm3,YMMWORD PTR [rdi+96]           # original input vx3

        vandps  ymm4,ymm0,ymm15                     # vsign0
        vandps  ymm5,ymm1,ymm15                     # vsign1
        vandps  ymm6,ymm2,ymm15                     # vsign2
        vandps  ymm7,ymm3,ymm15                     # vsign3
        vandnps ymm0,ymm15,ymm0                     # abs(vx0) va0
        vandnps ymm1,ymm15,ymm1                     # abs(vx1) va1
        vandnps ymm2,ymm15,ymm2                     # abs(vx2) va2
        vandnps ymm3,ymm15,ymm3                     # abs(vx3) va3

        vbroadcastss ymm14,ErfUpperAbsRange[rax]
        vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4    # save the sign bits
        vmovups YMMWORD PTR ErfBuffer0[rsp+32],ymm5
        vmovups YMMWORD PTR ErfBuffer0[rsp+64],ymm6
        vmovups YMMWORD PTR ErfBuffer0[rsp+96],ymm7

        vbroadcastss ymm8,ErfSMALL_P0[rax]
        vminps  ymm0,ymm0,ymm14                     # force abs value in range
        vminps  ymm1,ymm1,ymm14
        vminps  ymm2,ymm2,ymm14
        vminps  ymm3,ymm3,ymm14
        vmovaps ymm9,ymm8
        vmovaps ymm10,ymm8
        vmovaps ymm11,ymm8

        vbroadcastss ymm15,ErfSMALL_P1[rax]
        vmulps  ymm4,ymm0,ymm0                      # vs0 (square)
        vmulps  ymm5,ymm1,ymm1                      # vs1
        vmulps  ymm6,ymm2,ymm2                      # vs2
        vmulps  ymm7,ymm3,ymm3                      # vs3

        // Horner evaluation of the small-range polynomial in vs.

        vbroadcastss ymm14,ErfSMALL_P2[rax]
        vfmadd213ps ymm8,ymm4,ymm15
        vfmadd213ps ymm9,ymm5,ymm15
        vfmadd213ps ymm10,ymm6,ymm15
        vfmadd213ps ymm11,ymm7,ymm15

        vbroadcastss ymm13,ErfSMALL_P3[rax]
        vfmadd213ps ymm8,ymm4,ymm14
        vfmadd213ps ymm9,ymm5,ymm14
        vfmadd213ps ymm10,ymm6,ymm14
        vfmadd213ps ymm11,ymm7,ymm14

        vbroadcastss ymm15,ErfSMALL_P4[rax]
        vfmadd213ps ymm8,ymm4,ymm13
        vfmadd213ps ymm9,ymm5,ymm13
        vfmadd213ps ymm10,ymm6,ymm13
        vfmadd213ps ymm11,ymm7,ymm13

        vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax]
        vfmadd213ps ymm8,ymm4,ymm15
        vfmadd213ps ymm9,ymm5,ymm15
        vfmadd213ps ymm10,ymm6,ymm15
        vfmadd213ps ymm11,ymm7,ymm15

        vfmadd213ps ymm8,ymm4,ymm14
        vfmadd213ps ymm9,ymm5,ymm14
        vfmadd213ps ymm10,ymm6,ymm14
        vfmadd213ps ymm11,ymm7,ymm14

        vbroadcastss ymm12,ErfSplitBoundary[rax]
        vfmadd213ps ymm8,ymm0,ymm0                  # poly * va + va
        vfmadd213ps ymm9,ymm1,ymm1
        vfmadd213ps ymm10,ymm2,ymm2
        vfmadd213ps ymm11,ymm3,ymm3

        vcmpgtps ymm4,ymm0,ymm12                    # vmask0
        vcmpgtps ymm5,ymm1,ymm12                    # vmask1
        vcmpgtps ymm6,ymm2,ymm12                    # vmask2
        vcmpgtps ymm7,ymm3,ymm12                    # vmask3

        vandnps ymm8,ymm4,ymm8                      # zero the small result where va > boundary
        vandnps ymm9,ymm5,ymm9
        vandnps ymm10,ymm6,ymm10
        vandnps ymm11,ymm7,ymm11

        vbroadcastss ymm15,ErfBIG_P1[rax]
        vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8    # save the small-range result
        vmovups YMMWORD PTR ErfBuffer1[rsp+32],ymm9
        vmovups YMMWORD PTR ErfBuffer1[rsp+64],ymm10
        vmovups YMMWORD PTR ErfBuffer1[rsp+96],ymm11

.LBiggerNumbers:
        vbroadcastss ymm8,ErfBIG_P0[rax]
        vandps  ymm0,ymm4,ymm0                      # keep va only where va > boundary
        vandps  ymm1,ymm5,ymm1
        vandps  ymm2,ymm6,ymm2
        vandps  ymm3,ymm7,ymm3
        vmovaps ymm9,ymm8
        vmovaps ymm10,ymm8
        vmovaps ymm11,ymm8

        // Horner evaluation of the big-range polynomial in va.

        vbroadcastss ymm14,ErfBIG_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm15
        vfmadd213ps ymm9,ymm1,ymm15
        vfmadd213ps ymm10,ymm2,ymm15
        vfmadd213ps ymm11,ymm3,ymm15

        vbroadcastss ymm13,ErfBIG_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm14
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm15,ErfBIG_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm13
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vbroadcastss ymm14,ErfBIG_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15
        vfmadd213ps ymm9,ymm1,ymm15
        vfmadd213ps ymm10,ymm2,ymm15
        vfmadd213ps ymm11,ymm3,ymm15

        vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax]
        vfmadd213ps ymm8,ymm0,ymm14
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm15,ErfNegZero[rax]
        vfmadd213ps ymm8,ymm0,ymm13
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vbroadcastss ymm14,Exp_LowerRange[rax]
        vfmadd213ps ymm8,ymm0,ymm0                  # poly * va + va
        vfmadd213ps ymm9,ymm1,ymm1
        vfmadd213ps ymm10,ymm2,ymm2
        vfmadd213ps ymm11,ymm3,ymm3

        vbroadcastss ymm4,Exp_Log2Reciprocal[rax]
        vxorps  ymm8,ymm8,ymm15                     # negate: argument for expf
        vxorps  ymm9,ymm9,ymm15
        vxorps  ymm10,ymm10,ymm15
        vxorps  ymm11,ymm11,ymm15

        vbroadcastss ymm13,Exp_C[rax]
        vmovaps ymm5,ymm4
        vmovaps ymm6,ymm4
        vmovaps ymm7,ymm4

        // expf(ymm8 -- ymm11)

        vmaxps  ymm8,ymm8,ymm14                     # clamp to the exp lower range
        vmaxps  ymm9,ymm9,ymm14
        vmaxps  ymm10,ymm10,ymm14
        vmaxps  ymm11,ymm11,ymm14

        vbroadcastss ymm0,Exp_log2_hi[rax]
        vfmadd213ps ymm4,ymm8,ymm13
        vfmadd213ps ymm5,ymm9,ymm13
        vfmadd213ps ymm6,ymm10,ymm13
        vfmadd213ps ymm7,ymm11,ymm13

        vbroadcastss ymm15,Exp_log2_lo[rax]
        vmovaps ymm1,ymm0
        vmovaps ymm2,ymm0
        vmovaps ymm3,ymm0

        vsubps  ymm4,ymm4,ymm13                     # vr = round()
        vsubps  ymm5,ymm5,ymm13
        vsubps  ymm6,ymm6,ymm13
        vsubps  ymm7,ymm7,ymm13

        vfmadd213ps ymm0,ymm4,ymm8                  # vf = vr * log2_hi + ve
        vfmadd213ps ymm1,ymm5,ymm9
        vfmadd213ps ymm2,ymm6,ymm10
        vfmadd213ps ymm3,ymm7,ymm11

        vbroadcastss ymm8,Exp_P0[rax]
        vfmadd231ps ymm0,ymm4,ymm15                 # vf += vr * log_2_lo
        vfmadd231ps ymm1,ymm5,ymm15
        vfmadd231ps ymm2,ymm6,ymm15
        vfmadd231ps ymm3,ymm7,ymm15
        vmovaps ymm9,ymm8
        vmovaps ymm10,ymm8
        vmovaps ymm11,ymm8

        vbroadcastss ymm14,Exp_P1[rax]
        vbroadcastss ymm13,Exp_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p1
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm12,Exp_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p2
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vbroadcastss ymm15,Exp_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm12                 # *+ exp_p3
        vfmadd213ps ymm9,ymm1,ymm12
        vfmadd213ps ymm10,ymm2,ymm12
        vfmadd213ps ymm11,ymm3,ymm12

        vbroadcastss ymm14,Exp_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15                 # *+ exp_p4
        vfmadd213ps ymm9,ymm1,ymm15
        vfmadd213ps ymm10,ymm2,ymm15
        vfmadd213ps ymm11,ymm3,ymm15

        vbroadcastss ymm13,Exp_P6[rax]
        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p5
        vfmadd213ps ymm9,ymm1,ymm14
        vfmadd213ps ymm10,ymm2,ymm14
        vfmadd213ps ymm11,ymm3,ymm14

        vbroadcastss ymm12,Exp_X7F[rax]
        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p6
        vfmadd213ps ymm9,ymm1,ymm13
        vfmadd213ps ymm10,ymm2,ymm13
        vfmadd213ps ymm11,ymm3,ymm13

        vcvttps2dq ymm4,ymm4                        # vr as integer
        vcvttps2dq ymm5,ymm5
        vcvttps2dq ymm6,ymm6
        vcvttps2dq ymm7,ymm7

        vbroadcastss ymm15,ErfOne[rax]
        vpaddd  ymm4,ymm4,ymm12                     # +127
        vpaddd  ymm5,ymm5,ymm12
        vpaddd  ymm6,ymm6,ymm12
        vpaddd  ymm7,ymm7,ymm12

        vpslld  ymm4,ymm4,23                        # move into the float exponent field
        vpslld  ymm5,ymm5,23
        vpslld  ymm6,ymm6,23
        vpslld  ymm7,ymm7,23

        vmulps  ymm8,ymm8,ymm4                      # 2^i * exp(vf)
        vmulps  ymm9,ymm9,ymm5
        vmulps  ymm10,ymm10,ymm6
        vmulps  ymm11,ymm11,ymm7

        vsubps  ymm8,ymm15,ymm8                     # 1 - exp(...)
        vsubps  ymm9,ymm15,ymm9
        vsubps  ymm10,ymm15,ymm10
        vsubps  ymm11,ymm15,ymm11

        // Merge in the small-range result.

        vorps   ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp]
        vorps   ymm9,ymm9,YMMWORD PTR ErfBuffer1[rsp+32]
        vorps   ymm10,ymm10,YMMWORD PTR ErfBuffer1[rsp+64]
        vorps   ymm11,ymm11,YMMWORD PTR ErfBuffer1[rsp+96]

        // Copy the sign of the original input.

        vorps   ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp]
        vorps   ymm1,ymm9,YMMWORD PTR ErfBuffer0[rsp+32]
        vorps   ymm2,ymm10,YMMWORD PTR ErfBuffer0[rsp+64]
        vorps   ymm3,ymm11,YMMWORD PTR ErfBuffer0[rsp+96]

        vmovups YMMWORD PTR [rsi],ymm0
        vmovups YMMWORD PTR [rsi+32],ymm1
        vmovups YMMWORD PTR [rsi+64],ymm2
        vmovups YMMWORD PTR [rsi+96],ymm3

        add     rdi,32*4                            # advance by 4*8 elements
        add     rsi,32*4
        sub     rdx,32
        jae     .LComputeErf4x8Loop

.LErfProcessRemainingCount:
        add     rdx,32                              # correct for over-subtract above
        jz      .LErfBatchExp

        // The mask table address does not change, so load it once outside of
        // the remainder loop.

        mov     rcx,QWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)@GOTPCREL[rip]

//
// Remainder loop: process up to 8 elements per iteration using masked loads
// and stores.
//

.LErfProcess1x8:
        mov     DWORD PTR ErfKernelFrame_CountN[rsp],edx
        vbroadcastss ymm3,DWORD PTR ErfKernelFrame_CountN[rsp]

        // A lane is active when the remaining count is greater than the
        // corresponding MlasMaskMoveAvx entry (presumably the lane indices
        // 0..7; the table is defined outside of this file).

        vpcmpgtd ymm3,ymm3,YMMWORD PTR [rcx]
        vbroadcastss ymm15,ErfNegZero[rax]
        vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi]      # original input vx0

        vandps  ymm4,ymm0,ymm15                     # vsign0
        vandnps ymm0,ymm15,ymm0                     # abs(vx0) va0

        vbroadcastss ymm14,ErfUpperAbsRange[rax]
        vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4    # save the sign bits

        vbroadcastss ymm8,ErfSMALL_P0[rax]
        vminps  ymm0,ymm0,ymm14                     # force abs value in range

        vbroadcastss ymm15,ErfSMALL_P1[rax]
        vmulps  ymm4,ymm0,ymm0                      # vs0 (square)

        vbroadcastss ymm14,ErfSMALL_P2[rax]
        vfmadd213ps ymm8,ymm4,ymm15

        vbroadcastss ymm13,ErfSMALL_P3[rax]
        vfmadd213ps ymm8,ymm4,ymm14

        vbroadcastss ymm15,ErfSMALL_P4[rax]
        vfmadd213ps ymm8,ymm4,ymm13

        vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax]
        vfmadd213ps ymm8,ymm4,ymm15

        vfmadd213ps ymm8,ymm4,ymm14

        vbroadcastss ymm12,ErfSplitBoundary[rax]
        vfmadd213ps ymm8,ymm0,ymm0                  # poly * va + va

        vcmpgtps ymm4,ymm0,ymm12                    # vmask0

        vandnps ymm8,ymm4,ymm8                      # zero the small result where va > boundary

        vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8    # save the small-range result

.LBiggerNumbersRemaining:
        vbroadcastss ymm15,ErfBIG_P1[rax]
        vbroadcastss ymm8,ErfBIG_P0[rax]
        vandps  ymm0,ymm4,ymm0                      # keep va only where va > boundary

        vbroadcastss ymm14,ErfBIG_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm15

        vbroadcastss ymm13,ErfBIG_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm14

        vbroadcastss ymm15,ErfBIG_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm13

        vbroadcastss ymm14,ErfBIG_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15

        vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax]
        vfmadd213ps ymm8,ymm0,ymm14

        vbroadcastss ymm15,ErfNegZero[rax]
        vfmadd213ps ymm8,ymm0,ymm13

        vbroadcastss ymm14,Exp_LowerRange[rax]
        vfmadd213ps ymm8,ymm0,ymm0                  # poly * va + va

        vbroadcastss ymm4,Exp_Log2Reciprocal[rax]
        vxorps  ymm8,ymm8,ymm15                     # negate: argument for expf

        vbroadcastss ymm13,Exp_C[rax]

        // expf(ymm8)

        vmaxps  ymm8,ymm8,ymm14                     # clamp to the exp lower range

        vbroadcastss ymm0,Exp_log2_hi[rax]
        vfmadd213ps ymm4,ymm8,ymm13

        vbroadcastss ymm15,Exp_log2_lo[rax]

        vsubps  ymm4,ymm4,ymm13                     # vr = round()

        vfmadd213ps ymm0,ymm4,ymm8                  # vf = vr * log2_hi + ve

        vbroadcastss ymm8,Exp_P0[rax]

        vfmadd231ps ymm0,ymm4,ymm15                 # vf += vr * log_2_lo

        vbroadcastss ymm14,Exp_P1[rax]

        vbroadcastss ymm13,Exp_P2[rax]
        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p1

        vbroadcastss ymm12,Exp_P3[rax]
        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p2

        vbroadcastss ymm15,Exp_P4[rax]
        vfmadd213ps ymm8,ymm0,ymm12                 # *+ exp_p3

        vbroadcastss ymm14,Exp_P5[rax]
        vfmadd213ps ymm8,ymm0,ymm15                 # *+ exp_p4

        vbroadcastss ymm13,Exp_P6[rax]
        vfmadd213ps ymm8,ymm0,ymm14                 # *+ exp_p5

        vbroadcastss ymm12,Exp_X7F[rax]
        vfmadd213ps ymm8,ymm0,ymm13                 # *+ exp_p6

        vcvttps2dq ymm4,ymm4                        # vr as integer

        vbroadcastss ymm15,ErfOne[rax]
        vpaddd  ymm4,ymm4,ymm12                     # +127

        vpslld  ymm4,ymm4,23                        # move into the float exponent field

        vmulps  ymm8,ymm8,ymm4                      # 2^i * exp(vf)

        vsubps  ymm8,ymm15,ymm8                     # 1 - exp(...)

        // Merge in the small-range result.

        vorps   ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp]

        // Copy the sign of the original input.

        vorps   ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp]

        vmaskmovps YMMWORD PTR [rsi],ymm3,ymm0

        add     rdi,8*4
        add     rsi,8*4
        sub     rdx,8
        jg      .LErfProcess1x8

.LErfBatchExp:
        vzeroupper
        add     rsp,ErfKernelFrame_ReturnAddress
        ret

        .end
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include "core/util/math.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
|
@ -1032,7 +1033,8 @@ Status Erf<float>::Compute(OpKernelContext* context) const {
|
|||
ORT_ENFORCE(X_ptr != nullptr);
|
||||
auto& X = *X_ptr;
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
|
||||
|
||||
MlasComputeErf(X.template Data<float>(), Y.template MutableData<float>(), X.Shape().Size());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1025,6 +1025,35 @@ TEST(MathOpTest, Erf) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Runs the Erf operator (opset 9) over a 59 element 1-D tensor and compares
// against precomputed reference values. The data mixes magnitudes around
// 2..3.6, where the result is close to +/-1, with values near zero, and the
// element count is not a multiple of 8.
TEST(MathOpTest, ErfMoreData) {
  std::vector<float> x_values{
      -3.625f, 3.375f, 0.0f, 0.00025f, 0.0005f, -0.00075f, -0.001f, 0.00125f,
      0.0015f, -3.125f, 0.00175f, 2.875f, 2.625f, 2.375f, 2.125f, 6.25e-05f,
      0.0003125f, 0.0005625f, -0.0008125f, 0.0010625f, 0.0013125f, 0.0015625f, 0.0018125f, 3.5625f,
      3.3125f, 3.0625f, 2.8125f, -2.5625f, 2.3125f, 2.0625f, 0.000125f, 0.000375f,
      -0.000625f, -0.000875f, -0.001125f, -0.001375f, -0.001625f, -0.001875f, -3.5f, -3.25f,
      3.0f, 2.75f, -2.5f, -2.25f, -2.0f, -0.0001875f, 0.0004375f, 0.0006875f,
      2.1875f, -1.9375f, 0.0014375f, -0.0016875f, -0.0019375f, 3.4375f, 3.1875f, -2.9375f,
      -2.4375f, -0.0009375f, 0.0011875f};

  // expected_erf[i] is the reference value of erf(x_values[i]).
  std::vector<float> expected_erf{
      -1.0f, 0.999998f, 0.0f, 0.000282095f, 0.00056419f, -0.000846284f, -0.00112838f, 0.00141047f,
      0.00169257f, -0.99999f, 0.00197466f, 0.999952f, 0.999795f, 0.999217f, 0.997346f, 7.05237e-05f,
      0.000352618f, 0.000634713f, -0.000916808f, 0.0011989f, 0.001481f, 0.00176309f, 0.00204518f, 1.0f,
      0.999997f, 0.999985f, 0.99993f, -0.99971f, 0.998926f, 0.996464f, 0.000141047f, 0.000423142f,
      -0.000705237f, -0.000987331f, -0.00126943f, -0.00155152f, -0.00183361f, -0.00211571f, -0.999999f, -0.999996f,
      0.999978f, 0.999899f, -0.999593f, -0.998537f, -0.995322f, -0.000211571f, 0.000493666f, 0.000775761f,
      0.998022f, -0.993857f, 0.00162204f, -0.00190414f, -0.00218623f, 0.999999f, 0.999993f, -0.999967f,
      -0.999433f, -0.00105786f, 0.00133995f};

  // Input and output share the same 1-D shape.
  std::vector<int64_t> shape{static_cast<int64_t>(x_values.size())};

  OpTester test("Erf", 9);
  test.AddInput<float>("A", shape, x_values);
  test.AddOutput<float>("B", shape, expected_erf);
  test.Run();
}
|
||||
|
||||
const int ModOp_ver = 10;
|
||||
|
||||
TEST(ModOpTest, Fmod_float_mixed_sign) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue