mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Optimize quantized LSTM (#8634)
* optimize some lstm gate computation. Remove no need string constructions. * change gcc optimization flags for computation bound logics in rnn_helpers * better qgemm for M=1 * Some improve on avx512 * add condition to limit GCC related marcros * Correct QGemm assembly for M=1 AVX2 optimization to pass mlas_test. * Fix rnn_helper build issue for wasm. * better asm code here according to feedbacks. * Remove customized vectorize and unroll option for GCC. Using restrict on some function to help GCC to correctly vectorize it. Rewrite clip_add_bias() to let GCC correctly vectorize it. * Better restrict semantic for merge_lstm_gates_to_memory() by adding in_place(). Add MSC __restrict for the clip_add_bias() mthod to vectorize correctly. * Force CI restart as it stucked by the onnxruntime-python-checks-ci-pipeline which can not restart.
This commit is contained in:
parent
caacf249c5
commit
76dfe8108b
6 changed files with 96 additions and 83 deletions
|
|
@ -277,10 +277,32 @@ ComputeBlockU8S8AvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffs
|
|||
|
||||
ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
|
||||
|
||||
LOCAL ComputeBlockBy4Loop
|
||||
LOCAL ProcessRemainingBlocks
|
||||
LOCAL ComputeBlockBy1Loop
|
||||
LOCAL ComputeBlockLoopExit
|
||||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
IF (ColumnCount EQ 16) AND (RowCount EQ 1)
|
||||
sub rsi,4*4
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
ComputeBlockBy4Loop:
|
||||
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0*64, 0
|
||||
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 1*64, 4
|
||||
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 2*64, 8
|
||||
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 3*64, 12
|
||||
add rcx,4*4 ; advance matrix A by 4 quads
|
||||
add rdx,4*64 ; advance matrix B
|
||||
sub rsi,4*4
|
||||
jae ComputeBlockBy4Loop
|
||||
|
||||
ProcessRemainingBlocks:
|
||||
add rsi,4*4 ; correct for over-subtract above
|
||||
jz ComputeBlockLoopExit
|
||||
ENDIF
|
||||
|
||||
ComputeBlockBy1Loop:
|
||||
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0, 0
|
||||
add rcx,4 ; advance matrix A by 1 quad
|
||||
|
|
@ -291,6 +313,8 @@ ENDIF
|
|||
sub rsi,4
|
||||
jnz ComputeBlockBy1Loop
|
||||
|
||||
ComputeBlockLoopExit:
|
||||
|
||||
ENDM
|
||||
|
||||
;
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
|
|||
|
||||
mov rsi,r9 ; reload row length remaining
|
||||
|
||||
IF ((RowCount AND 1) EQ 0)
|
||||
IF (RowCount EQ 1) OR ((RowCount AND 1) EQ 0)
|
||||
sub rsi,4*4
|
||||
jb ProcessRemainingBlocks
|
||||
|
||||
|
|
|
|||
|
|
@ -264,6 +264,25 @@ Implicit Arguments:
|
|||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if (\ColumnCount\() == 16) && (\RowCount\() == 1)
|
||||
sub rbp,4*4
|
||||
jb .LProcessRemainingBlocks\@
|
||||
|
||||
.LComputeBlockBy4Loop\@:
|
||||
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0*64, 0
|
||||
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 1*64, 4
|
||||
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 2*64, 8
|
||||
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 3*64, 12
|
||||
add rdi,4*4 # advance matrix A by 4 quads
|
||||
add rsi,4*64 # advance matrix B
|
||||
sub rbp,4*4
|
||||
jae .LComputeBlockBy4Loop\@
|
||||
|
||||
.LProcessRemainingBlocks\@:
|
||||
add rbp,4*4 # correct for over-subtract above
|
||||
jz .LComputeBlockLoopExit\@
|
||||
.endif
|
||||
|
||||
.LComputeBlockBy1Loop\@:
|
||||
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
|
||||
add rdi,4 # advance matrix A by 1 quad
|
||||
|
|
@ -274,6 +293,8 @@ Implicit Arguments:
|
|||
sub rbp,4
|
||||
jnz .LComputeBlockBy1Loop\@
|
||||
|
||||
.LComputeBlockLoopExit\@:
|
||||
|
||||
.endm
|
||||
|
||||
/*++
|
||||
|
|
|
|||
|
|
@ -218,7 +218,7 @@ Implicit Arguments:
|
|||
|
||||
mov rbp,rcx # reload row length remaining
|
||||
|
||||
.if ((\RowCount\() & 1) == 0)
|
||||
.if (\RowCount\() == 1) || ((\RowCount\() & 1) == 0)
|
||||
sub rbp,4*4
|
||||
jb .LProcessRemainingBlocks\@
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/providers/cpu/rnn/rnn_activation_functors.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
|
@ -292,7 +293,7 @@ void ComputeGemm(const int M,
|
|||
gemm_shape.N = static_cast<size_t>(N);
|
||||
gemm_shape.K = static_cast<size_t>(K);
|
||||
gemm_shape.BIsSigned = b_is_signed;
|
||||
|
||||
|
||||
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
|
||||
gemm_params.A = quantized_A_buffer;
|
||||
gemm_params.lda = static_cast<size_t>(K);
|
||||
|
|
@ -326,6 +327,14 @@ const float beta_6 = 1.19825839466702e-06f;
|
|||
const float sigmoid_bound = 20.0f;
|
||||
const float tanh_bound = 10.0f;
|
||||
|
||||
#if defined(__GNUC__) && !defined(__wasm__)
|
||||
#define restrict __restrict__
|
||||
#elif defined(_MSC_VER)
|
||||
#define restrict __restrict
|
||||
#else
|
||||
#define restrict
|
||||
#endif
|
||||
|
||||
inline void clip_for_sigmoid_in_place(float* ps, int c) {
|
||||
for (int i = 0; i < c; i++) {
|
||||
if (ps[i] < -sigmoid_bound)
|
||||
|
|
@ -402,67 +411,41 @@ void clip_ignore_bias(const float b, const float* pb, float* pd, int c) {
|
|||
}
|
||||
}
|
||||
|
||||
void clip_add_bias(const float b, const float* pb, float* pd, int c) {
|
||||
void clip_add_bias(const float b, const float* restrict pb, float* restrict pd, int c) {
|
||||
for (int i = 0; i < c; i++) {
|
||||
float x = pd[i] + pb[i];
|
||||
if (x > b)
|
||||
pd[i] = b;
|
||||
else if (x < -b)
|
||||
pd[i] = -b;
|
||||
else
|
||||
pd[i] = x;
|
||||
x = std::min(b, x);
|
||||
x = std::max(-b, x);
|
||||
pd[i] = x;
|
||||
}
|
||||
}
|
||||
|
||||
void sigmoid_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c,
|
||||
void sigmoid_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c,
|
||||
const float alpha, const float beta) {
|
||||
ORT_UNUSED_PARAMETER(alpha);
|
||||
ORT_UNUSED_PARAMETER(beta);
|
||||
ORT_UNUSED_PARAMETER(ps1_c);
|
||||
|
||||
clip_for_sigmoid(ps1, ps1_c, c);
|
||||
|
||||
MlasComputeLogistic(ps1, pd, c);
|
||||
for (int i = 0; i < c; i++) {
|
||||
float x = 0.5f * ps1_c[i];
|
||||
float x2 = x * x;
|
||||
float p = x2 * alpha_13 + alpha_11;
|
||||
p = x2 * p + alpha_9;
|
||||
p = x2 * p + alpha_7;
|
||||
p = x2 * p + alpha_5;
|
||||
p = x2 * p + alpha_3;
|
||||
p = x2 * p + alpha_1;
|
||||
p = x * p;
|
||||
float q = x2 * beta_6 + beta_4;
|
||||
q = x2 * q + beta_2;
|
||||
q = x2 * q + beta_0;
|
||||
pd[i] = ps2[i] * 0.5f * (1 + (p / q));
|
||||
pd[i] *= ps2[i];
|
||||
}
|
||||
}
|
||||
|
||||
void tanh_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c,
|
||||
void tanh_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c,
|
||||
const float alpha, const float beta) {
|
||||
ORT_UNUSED_PARAMETER(alpha);
|
||||
ORT_UNUSED_PARAMETER(beta);
|
||||
ORT_UNUSED_PARAMETER(ps1_c);
|
||||
|
||||
clip_for_tanh(ps1, ps1_c, c);
|
||||
|
||||
MlasComputeTanh(ps1, pd, c);
|
||||
for (int i = 0; i < c; i++) {
|
||||
float x = ps1_c[i];
|
||||
float x2 = x * x;
|
||||
float p = x2 * alpha_13 + alpha_11;
|
||||
p = x2 * p + alpha_9;
|
||||
p = x2 * p + alpha_7;
|
||||
p = x2 * p + alpha_5;
|
||||
p = x2 * p + alpha_3;
|
||||
p = x2 * p + alpha_1;
|
||||
p = x * p;
|
||||
float q = x2 * beta_6 + beta_4;
|
||||
q = x2 * q + beta_2;
|
||||
q = x2 * q + beta_0;
|
||||
pd[i] = ps2[i] * p / q;
|
||||
pd[i] *= ps2[i];
|
||||
}
|
||||
}
|
||||
|
||||
void relu_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c, float alpha, float beta) {
|
||||
void relu_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c,
|
||||
const float alpha, const float beta) {
|
||||
ORT_UNUSED_PARAMETER(ps1_c);
|
||||
ORT_UNUSED_PARAMETER(alpha);
|
||||
ORT_UNUSED_PARAMETER(beta);
|
||||
|
|
@ -507,56 +490,23 @@ void sigmoid(float* pd, int c, float alpha, float beta) {
|
|||
ORT_UNUSED_PARAMETER(alpha);
|
||||
ORT_UNUSED_PARAMETER(beta);
|
||||
|
||||
clip_for_sigmoid_in_place(pd, c);
|
||||
|
||||
for (int i = 0; i < c; i++) {
|
||||
float x = 0.5f * pd[i];
|
||||
float x2 = x * x;
|
||||
float p = x2 * alpha_13 + alpha_11;
|
||||
p = x2 * p + alpha_9;
|
||||
p = x2 * p + alpha_7;
|
||||
p = x2 * p + alpha_5;
|
||||
p = x2 * p + alpha_3;
|
||||
p = x2 * p + alpha_1;
|
||||
p = x * p;
|
||||
float q = x2 * beta_6 + beta_4;
|
||||
q = x2 * q + beta_2;
|
||||
q = x2 * q + beta_0;
|
||||
pd[i] = 0.5f * (1 + (p / q));
|
||||
}
|
||||
MlasComputeLogistic(pd, pd, c);
|
||||
}
|
||||
|
||||
void tanh(float* pd, int c, float alpha, float beta) {
|
||||
ORT_UNUSED_PARAMETER(alpha);
|
||||
ORT_UNUSED_PARAMETER(beta);
|
||||
|
||||
clip_for_tanh_in_place(pd, c);
|
||||
|
||||
for (int i = 0; i < c; i++) {
|
||||
float x = pd[i];
|
||||
float x2 = x * x;
|
||||
float p = x2 * alpha_13 + alpha_11;
|
||||
p = x2 * p + alpha_9;
|
||||
p = x2 * p + alpha_7;
|
||||
p = x2 * p + alpha_5;
|
||||
p = x2 * p + alpha_3;
|
||||
p = x2 * p + alpha_1;
|
||||
p = x * p;
|
||||
float q = x2 * beta_6 + beta_4;
|
||||
q = x2 * q + beta_2;
|
||||
q = x2 * q + beta_0;
|
||||
pd[i] = p / q;
|
||||
}
|
||||
MlasComputeTanh(pd, pd, c);
|
||||
}
|
||||
|
||||
void relu(float* pd, int c, float alpha, float beta) {
|
||||
ORT_UNUSED_PARAMETER(alpha);
|
||||
ORT_UNUSED_PARAMETER(beta);
|
||||
|
||||
for (int i = 0; i < c; i++) {
|
||||
if (pd[i] < 0)
|
||||
pd[i] = 0.0f;
|
||||
}
|
||||
MLAS_ACTIVATION activation;
|
||||
activation.ActivationKind = MlasReluActivation;
|
||||
MlasActivation(&activation, pd, nullptr, 1, c, c);
|
||||
}
|
||||
|
||||
void sigmoid_exact(float* pd, int c, float alpha, float beta) {
|
||||
|
|
@ -579,8 +529,23 @@ void tanh_exact(float* pd, int c, float alpha, float beta) {
|
|||
}
|
||||
}
|
||||
|
||||
void merge_lstm_gates_to_memory(const float* pprev, const float* pi, const float* pf, const float* pg, float* pcurr,
|
||||
int c) {
|
||||
// Help compiler simply and correctly optimize for pcurr == pprev case.
|
||||
// Although without this in_place(), if restrict pprev and pcur, compiler could also work.
|
||||
// Yet this in_place() follow the restrict semantic better.
|
||||
static void merge_lstm_gates_to_memory_in_place(const float* restrict pi, const float* restrict pf,
|
||||
const float* restrict pg, float* restrict pcurr, int c) {
|
||||
for (int i = 0; i < c; i++) {
|
||||
pcurr[i] = pcurr[i] * pf[i] + pi[i] * pg[i];
|
||||
}
|
||||
}
|
||||
|
||||
void merge_lstm_gates_to_memory(const float* pprev, const float* restrict pi, const float* restrict pf,
|
||||
const float* restrict pg, float* pcurr, int c) {
|
||||
if (pprev == pcurr) {
|
||||
merge_lstm_gates_to_memory_in_place(pi, pf, pg, pcurr, c);
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < c; i++) {
|
||||
pcurr[i] = pprev[i] * pf[i] + pi[i] * pg[i];
|
||||
}
|
||||
|
|
@ -953,3 +918,4 @@ GruOutputGateFuncPtr GruOutputGateFuncByName(const std::string& func) {
|
|||
} // namespace detail
|
||||
} // namespace rnn
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -536,8 +536,10 @@ void UniDirectionalLstm<T>::GateComputations(
|
|||
// DumpMatrix("H" + row_str, pH, 1, hidden_size_);
|
||||
}
|
||||
|
||||
#if defined(DUMP_MATRIXES)
|
||||
auto num_rows = local_fused_hidden_rows - row;
|
||||
std::string rows_str = " rows[" + std::to_string(row) + ".." + std::to_string(num_rows) + "]";
|
||||
#endif
|
||||
|
||||
DumpMatrix("i" + rows_str, &*out, num_rows, hidden_size_, 0, hidden_size_x4);
|
||||
DumpMatrix("o" + rows_str, &*out, num_rows, hidden_size_, 1 * hidden_size_, hidden_size_x4);
|
||||
|
|
|
|||
Loading…
Reference in a new issue