Optimize quantized LSTM (#8634)

* optimize some lstm gate computation. Remove no need string constructions.

* change gcc optimization flags for computation bound logics in rnn_helpers

* better qgemm for M=1

* Some improve on avx512

* add condition to limit GCC related marcros

* Correct QGemm assembly for M=1 AVX2 optimization to pass mlas_test.

* Fix rnn_helper build issue for wasm.

* better asm code here according to feedbacks.

* Remove customized vectorize and unroll option for GCC.
Using restrict on some function to help GCC to correctly vectorize it.
Rewrite clip_add_bias() to let GCC correctly vectorize it.

* Better restrict semantic for merge_lstm_gates_to_memory() by adding in_place().
Add MSC __restrict for the clip_add_bias() mthod to vectorize correctly.

* Force CI restart as it stucked by the onnxruntime-python-checks-ci-pipeline which can not restart.
This commit is contained in:
Zhang Lei 2021-08-11 22:02:18 -07:00 committed by GitHub
parent caacf249c5
commit 76dfe8108b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 96 additions and 83 deletions

View file

@ -277,10 +277,32 @@ ComputeBlockU8S8AvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffs
ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
LOCAL ComputeBlockBy4Loop
LOCAL ProcessRemainingBlocks
LOCAL ComputeBlockBy1Loop
LOCAL ComputeBlockLoopExit
mov rsi,r9 ; reload row length remaining
IF (ColumnCount EQ 16) AND (RowCount EQ 1)
sub rsi,4*4
jb ProcessRemainingBlocks
ComputeBlockBy4Loop:
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0*64, 0
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 1*64, 4
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 2*64, 8
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 3*64, 12
add rcx,4*4 ; advance matrix A by 4 quads
add rdx,4*64 ; advance matrix B
sub rsi,4*4
jae ComputeBlockBy4Loop
ProcessRemainingBlocks:
add rsi,4*4 ; correct for over-subtract above
jz ComputeBlockLoopExit
ENDIF
ComputeBlockBy1Loop:
ComputeBlockU8S8&Isa& ColumnCount, RowCount, 0, 0
add rcx,4 ; advance matrix A by 1 quad
@ -291,6 +313,8 @@ ENDIF
sub rsi,4
jnz ComputeBlockBy1Loop
ComputeBlockLoopExit:
ENDM
;

View file

@ -233,7 +233,7 @@ ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount
mov rsi,r9 ; reload row length remaining
IF ((RowCount AND 1) EQ 0)
IF (RowCount EQ 1) OR ((RowCount AND 1) EQ 0)
sub rsi,4*4
jb ProcessRemainingBlocks

View file

@ -264,6 +264,25 @@ Implicit Arguments:
mov rbp,rcx # reload row length remaining
.if (\ColumnCount\() == 16) && (\RowCount\() == 1)
sub rbp,4*4
jb .LProcessRemainingBlocks\@
.LComputeBlockBy4Loop\@:
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0*64, 0
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 1*64, 4
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 2*64, 8
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 3*64, 12
add rdi,4*4 # advance matrix A by 4 quads
add rsi,4*64 # advance matrix B
sub rbp,4*4
jae .LComputeBlockBy4Loop\@
.LProcessRemainingBlocks\@:
add rbp,4*4 # correct for over-subtract above
jz .LComputeBlockLoopExit\@
.endif
.LComputeBlockBy1Loop\@:
ComputeBlockU8S8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0
add rdi,4 # advance matrix A by 1 quad
@ -274,6 +293,8 @@ Implicit Arguments:
sub rbp,4
jnz .LComputeBlockBy1Loop\@
.LComputeBlockLoopExit\@:
.endm
/*++

View file

@ -218,7 +218,7 @@ Implicit Arguments:
mov rbp,rcx # reload row length remaining
.if ((\RowCount\() & 1) == 0)
.if (\RowCount\() == 1) || ((\RowCount\() & 1) == 0)
sub rbp,4*4
jb .LProcessRemainingBlocks\@

View file

@ -13,6 +13,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/mlas/inc/mlas.h"
#include "core/providers/cpu/rnn/rnn_activation_functors.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
@ -292,7 +293,7 @@ void ComputeGemm(const int M,
gemm_shape.N = static_cast<size_t>(N);
gemm_shape.K = static_cast<size_t>(K);
gemm_shape.BIsSigned = b_is_signed;
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
gemm_params.A = quantized_A_buffer;
gemm_params.lda = static_cast<size_t>(K);
@ -326,6 +327,14 @@ const float beta_6 = 1.19825839466702e-06f;
const float sigmoid_bound = 20.0f;
const float tanh_bound = 10.0f;
#if defined(__GNUC__) && !defined(__wasm__)
#define restrict __restrict__
#elif defined(_MSC_VER)
#define restrict __restrict
#else
#define restrict
#endif
inline void clip_for_sigmoid_in_place(float* ps, int c) {
for (int i = 0; i < c; i++) {
if (ps[i] < -sigmoid_bound)
@ -402,67 +411,41 @@ void clip_ignore_bias(const float b, const float* pb, float* pd, int c) {
}
}
void clip_add_bias(const float b, const float* pb, float* pd, int c) {
void clip_add_bias(const float b, const float* restrict pb, float* restrict pd, int c) {
for (int i = 0; i < c; i++) {
float x = pd[i] + pb[i];
if (x > b)
pd[i] = b;
else if (x < -b)
pd[i] = -b;
else
pd[i] = x;
x = std::min(b, x);
x = std::max(-b, x);
pd[i] = x;
}
}
void sigmoid_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c,
void sigmoid_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c,
const float alpha, const float beta) {
ORT_UNUSED_PARAMETER(alpha);
ORT_UNUSED_PARAMETER(beta);
ORT_UNUSED_PARAMETER(ps1_c);
clip_for_sigmoid(ps1, ps1_c, c);
MlasComputeLogistic(ps1, pd, c);
for (int i = 0; i < c; i++) {
float x = 0.5f * ps1_c[i];
float x2 = x * x;
float p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;
float q = x2 * beta_6 + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;
pd[i] = ps2[i] * 0.5f * (1 + (p / q));
pd[i] *= ps2[i];
}
}
void tanh_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c,
void tanh_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c,
const float alpha, const float beta) {
ORT_UNUSED_PARAMETER(alpha);
ORT_UNUSED_PARAMETER(beta);
ORT_UNUSED_PARAMETER(ps1_c);
clip_for_tanh(ps1, ps1_c, c);
MlasComputeTanh(ps1, pd, c);
for (int i = 0; i < c; i++) {
float x = ps1_c[i];
float x2 = x * x;
float p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;
float q = x2 * beta_6 + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;
pd[i] = ps2[i] * p / q;
pd[i] *= ps2[i];
}
}
void relu_m(const float* ps1, float* ps1_c, const float* ps2, float* pd, int c, float alpha, float beta) {
void relu_m(const float* restrict ps1, float* restrict ps1_c, const float* restrict ps2, float* restrict pd, int c,
const float alpha, const float beta) {
ORT_UNUSED_PARAMETER(ps1_c);
ORT_UNUSED_PARAMETER(alpha);
ORT_UNUSED_PARAMETER(beta);
@ -507,56 +490,23 @@ void sigmoid(float* pd, int c, float alpha, float beta) {
ORT_UNUSED_PARAMETER(alpha);
ORT_UNUSED_PARAMETER(beta);
clip_for_sigmoid_in_place(pd, c);
for (int i = 0; i < c; i++) {
float x = 0.5f * pd[i];
float x2 = x * x;
float p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;
float q = x2 * beta_6 + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;
pd[i] = 0.5f * (1 + (p / q));
}
MlasComputeLogistic(pd, pd, c);
}
void tanh(float* pd, int c, float alpha, float beta) {
ORT_UNUSED_PARAMETER(alpha);
ORT_UNUSED_PARAMETER(beta);
clip_for_tanh_in_place(pd, c);
for (int i = 0; i < c; i++) {
float x = pd[i];
float x2 = x * x;
float p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;
float q = x2 * beta_6 + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;
pd[i] = p / q;
}
MlasComputeTanh(pd, pd, c);
}
void relu(float* pd, int c, float alpha, float beta) {
ORT_UNUSED_PARAMETER(alpha);
ORT_UNUSED_PARAMETER(beta);
for (int i = 0; i < c; i++) {
if (pd[i] < 0)
pd[i] = 0.0f;
}
MLAS_ACTIVATION activation;
activation.ActivationKind = MlasReluActivation;
MlasActivation(&activation, pd, nullptr, 1, c, c);
}
void sigmoid_exact(float* pd, int c, float alpha, float beta) {
@ -579,8 +529,23 @@ void tanh_exact(float* pd, int c, float alpha, float beta) {
}
}
void merge_lstm_gates_to_memory(const float* pprev, const float* pi, const float* pf, const float* pg, float* pcurr,
int c) {
// Help compiler simply and correctly optimize for pcurr == pprev case.
// Although without this in_place(), if restrict pprev and pcur, compiler could also work.
// Yet this in_place() follow the restrict semantic better.
static void merge_lstm_gates_to_memory_in_place(const float* restrict pi, const float* restrict pf,
const float* restrict pg, float* restrict pcurr, int c) {
for (int i = 0; i < c; i++) {
pcurr[i] = pcurr[i] * pf[i] + pi[i] * pg[i];
}
}
void merge_lstm_gates_to_memory(const float* pprev, const float* restrict pi, const float* restrict pf,
const float* restrict pg, float* pcurr, int c) {
if (pprev == pcurr) {
merge_lstm_gates_to_memory_in_place(pi, pf, pg, pcurr, c);
return;
}
for (int i = 0; i < c; i++) {
pcurr[i] = pprev[i] * pf[i] + pi[i] * pg[i];
}
@ -953,3 +918,4 @@ GruOutputGateFuncPtr GruOutputGateFuncByName(const std::string& func) {
} // namespace detail
} // namespace rnn
} // namespace onnxruntime

View file

@ -536,8 +536,10 @@ void UniDirectionalLstm<T>::GateComputations(
// DumpMatrix("H" + row_str, pH, 1, hidden_size_);
}
#if defined(DUMP_MATRIXES)
auto num_rows = local_fused_hidden_rows - row;
std::string rows_str = " rows[" + std::to_string(row) + ".." + std::to_string(num_rows) + "]";
#endif
DumpMatrix("i" + rows_str, &*out, num_rows, hidden_size_, 0, hidden_size_x4);
DumpMatrix("o" + rows_str, &*out, num_rows, hidden_size_, 1 * hidden_size_, hidden_size_x4);