diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm index 4b345d07fe..210f28bdd2 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm @@ -277,10 +277,32 @@ ComputeBlockU8S8AvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffs ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount + LOCAL ComputeBlockBy4Loop + LOCAL ProcessRemainingBlocks LOCAL ComputeBlockBy1Loop + LOCAL ComputeBlockLoopExit mov rsi,r9 ; reload row length remaining +IF (ColumnCount EQ 16) AND (RowCount EQ 1) + sub rsi,4*4 + jb ProcessRemainingBlocks + +ComputeBlockBy4Loop: + ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0*64, 0 + ComputeBlockU8S8&Isa& ColumnCount, RowCount, 1*64, 4 + ComputeBlockU8S8&Isa& ColumnCount, RowCount, 2*64, 8 + ComputeBlockU8S8&Isa& ColumnCount, RowCount, 3*64, 12 + add rcx,4*4 ; advance matrix A by 4 quads + add rdx,4*64 ; advance matrix B + sub rsi,4*4 + jae ComputeBlockBy4Loop + +ProcessRemainingBlocks: + add rsi,4*4 ; correct for over-subtract above + jz ComputeBlockLoopExit +ENDIF + ComputeBlockBy1Loop: ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0, 0 add rcx,4 ; advance matrix A by 1 quad @@ -291,6 +313,8 @@ ENDIF sub rsi,4 jnz ComputeBlockBy1Loop +ComputeBlockLoopExit: + ENDM ; diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm index bcd3f52b7e..9606c3ddb0 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm @@ -233,7 +233,7 @@ ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount mov rsi,r9 ; reload row length remaining -IF ((RowCount AND 1) EQ 0) +IF (RowCount EQ 1) OR ((RowCount AND 1) EQ 0) sub rsi,4*4 jb ProcessRemainingBlocks diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S 
b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S index b9caba221c..b0f7be63c4 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S @@ -264,6 +264,25 @@ Implicit Arguments: mov rbp,rcx # reload row length remaining +.if (\ColumnCount\() == 16) && (\RowCount\() == 1) + sub rbp,4*4 + jb .LProcessRemainingBlocks\@ + +.LComputeBlockBy4Loop\@: + ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0*64, 0 + ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 1*64, 4 + ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 2*64, 8 + ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 3*64, 12 + add rdi,4*4 # advance matrix A by 4 quads + add rsi,4*64 # advance matrix B + sub rbp,4*4 + jae .LComputeBlockBy4Loop\@ + +.LProcessRemainingBlocks\@: + add rbp,4*4 # correct for over-subtract above + jz .LComputeBlockLoopExit\@ +.endif + .LComputeBlockBy1Loop\@: ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 add rdi,4 # advance matrix A by 1 quad @@ -274,6 +293,8 @@ Implicit Arguments: sub rbp,4 jnz .LComputeBlockBy1Loop\@ +.LComputeBlockLoopExit\@: + .endm /*++ diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S index 156091221e..279f406b88 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S @@ -218,7 +218,7 @@ Implicit Arguments: mov rbp,rcx # reload row length remaining -.if ((\RowCount\() & 1) == 0) +.if (\RowCount\() == 1) || ((\RowCount\() & 1) == 0) sub rbp,4*4 jb .LProcessRemainingBlocks\@ diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc index 0311a20e26..f52f9fdda9 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc @@ -13,6 +13,7 @@ #include "core/common/common.h" 
#include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" #include "core/providers/cpu/rnn/rnn_activation_functors.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" @@ -292,7 +293,7 @@ void ComputeGemm(const int M, gemm_shape.N = static_cast(N); gemm_shape.K = static_cast(K); gemm_shape.BIsSigned = b_is_signed; - + MLAS_GEMM_U8X8_DATA_PARAMS gemm_params; gemm_params.A = quantized_A_buffer; gemm_params.lda = static_cast(K); @@ -326,6 +327,14 @@ const float beta_6 = 1.19825839466702e-06f; const float sigmoid_bound = 20.0f; const float tanh_bound = 10.0f; +#if defined(__GNUC__) && !defined(__wasm__) +#define restrict __restrict__ +#elif defined(_MSC_VER) +#define restrict __restrict +#else +#define restrict +#endif + inline void clip_for_sigmoid_in_place(float* ps, int c) { for (int i = 0; i < c; i++) { if (ps[i] < -sigmoid_bound) @@ -402,67 +411,41 @@ void clip_ignore_bias(const float b, const float* pb, float* pd, int c) { } } -void clip_add_bias(const float b, const float* pb, float* pd, int c) { +void clip_add_bias(const float b, const float* restrict pb, float* restrict pd, int c) { for (int i = 0; i < c; i++) { float x = pd[i] + pb[i]; - if (x > b) - pd[i] = b; - else if (x < -b) - pd[i] = -b; - else - pd[i] = x; + x = std::min(b, x); + x = std::max(-b, x); + pd[i] = x; } } -void sigmoid_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c, +void sigmoid_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c, const float alpha, const float beta) { ORT_UNUSED_PARAMETER(alpha); ORT_UNUSED_PARAMETER(beta); + ORT_UNUSED_PARAMETER(ps1_c); - clip_for_sigmoid(ps1, ps1_c, c); - + MlasComputeLogistic(ps1, pd, c); for (int i = 0; i < c; i++) { - float x = 0.5f * ps1_c[i]; - float x2 = x * x; - float p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - float q = x2 * 
beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - pd[i] = ps2[i] * 0.5f * (1 + (p / q)); + pd[i] *= ps2[i]; } } -void tanh_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c, +void tanh_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c, const float alpha, const float beta) { ORT_UNUSED_PARAMETER(alpha); ORT_UNUSED_PARAMETER(beta); + ORT_UNUSED_PARAMETER(ps1_c); - clip_for_tanh(ps1, ps1_c, c); - + MlasComputeTanh(ps1, pd, c); for (int i = 0; i < c; i++) { - float x = ps1_c[i]; - float x2 = x * x; - float p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - float q = x2 * beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - pd[i] = ps2[i] * p / q; + pd[i] *= ps2[i]; } } -void relu_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c, float alpha, float beta) { +void relu_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c, + const float alpha, const float beta) { ORT_UNUSED_PARAMETER(ps1_c); ORT_UNUSED_PARAMETER(alpha); ORT_UNUSED_PARAMETER(beta); @@ -507,56 +490,23 @@ void sigmoid(float* pd, int c, float alpha, float beta) { ORT_UNUSED_PARAMETER(alpha); ORT_UNUSED_PARAMETER(beta); - clip_for_sigmoid_in_place(pd, c); - - for (int i = 0; i < c; i++) { - float x = 0.5f * pd[i]; - float x2 = x * x; - float p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - float q = x2 * beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - pd[i] = 0.5f * (1 + (p / q)); - } + MlasComputeLogistic(pd, pd, c); } void tanh(float* pd, int c, float alpha, float beta) { ORT_UNUSED_PARAMETER(alpha); ORT_UNUSED_PARAMETER(beta); - clip_for_tanh_in_place(pd, c); - - for (int i = 0; i < c; i++) { - 
float x = pd[i]; - float x2 = x * x; - float p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - float q = x2 * beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - pd[i] = p / q; - } + MlasComputeTanh(pd, pd, c); } void relu(float* pd, int c, float alpha, float beta) { ORT_UNUSED_PARAMETER(alpha); ORT_UNUSED_PARAMETER(beta); - for (int i = 0; i < c; i++) { - if (pd[i] < 0) - pd[i] = 0.0f; - } + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasReluActivation; + MlasActivation(&activation, pd, nullptr, 1, c, c); } void sigmoid_exact(float* pd, int c, float alpha, float beta) { @@ -579,8 +529,23 @@ void tanh_exact(float* pd, int c, float alpha, float beta) { } } -void merge_lstm_gates_to_memory(const float* pprev, const float* pi, const float* pf, const float* pg, float* pcurr, - int c) { +// Helps the compiler simply and correctly optimize the pcurr == pprev case. +// Without this in_place() variant, marking pprev and pcurr as restrict could also work in practice, +// but this in_place() variant follows the restrict semantics better.
+static void merge_lstm_gates_to_memory_in_place(const float* restrict pi, const float* restrict pf, + const float* restrict pg, float* restrict pcurr, int c) { + for (int i = 0; i < c; i++) { + pcurr[i] = pcurr[i] * pf[i] + pi[i] * pg[i]; + } +} + +void merge_lstm_gates_to_memory(const float* pprev, const float* restrict pi, const float* restrict pf, + const float* restrict pg, float* pcurr, int c) { + if (pprev == pcurr) { + merge_lstm_gates_to_memory_in_place(pi, pf, pg, pcurr, c); + return; + } + for (int i = 0; i < c; i++) { pcurr[i] = pprev[i] * pf[i] + pi[i] * pg[i]; } @@ -953,3 +918,4 @@ GruOutputGateFuncPtr GruOutputGateFuncByName(const std::string& func) { } // namespace detail } // namespace rnn } // namespace onnxruntime + diff --git a/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc b/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc index 7aa8a243fd..715aa931a3 100644 --- a/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc @@ -536,8 +536,10 @@ void UniDirectionalLstm::GateComputations( // DumpMatrix("H" + row_str, pH, 1, hidden_size_); } +#if defined(DUMP_MATRIXES) auto num_rows = local_fused_hidden_rows - row; std::string rows_str = " rows[" + std::to_string(row) + ".." + std::to_string(num_rows) + "]"; +#endif DumpMatrix("i" + rows_str, &*out, num_rows, hidden_size_, 0, hidden_size_x4); DumpMatrix("o" + rows_str, &*out, num_rows, hidden_size_, 1 * hidden_size_, hidden_size_x4);