Tuning GRU performance for batch size >= 2 (#644)

GRU with batch size >1 is implemented on the assumption that Lotus use single-thread Eigen Gemm. The assumption doesn't hold anymore. MLAS and MKLML support multi-thread. We don't rely eigen gemm anymore.
This PR implements batch size > 1 as batch size ==1. With this change, we have about 2x performance gain for GRU.Please refer to the performance test below:
(ms)
Batch_size | Seq_length | input_size | hiddden_size | Old | New
8 | 30 | 512 | 512 | 19.16 | 10.47
16 | 30 | 512 | 512 | 28.13 | 15.15
32 | 30 | 512 | 512 | 36.97 | 26.89
8 | 30 | 1024 | 1024 | 142.853 | 55.67
16 | 30 | 1024 | 1024 | 184.397 | 72.32
32 | 30 | 1024 | 1024 236.364 | 112.78
This commit is contained in:
Yufeng Li 2019-04-23 14:50:11 -07:00 committed by GitHub
parent 80d69515ed
commit d0f846aad5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 186 additions and 448 deletions

View file

@ -161,8 +161,6 @@ using namespace rnn::detail;
// internal helper code
namespace detail {
/// The class represents DeepCPU implementation of a gated recurrent unit (GRU) operator.
/// For details, refer to http://aka.ms/dl-optimization/.
template <typename T>
class UniDirectionalGru {
public:
@ -178,8 +176,7 @@ class UniDirectionalGru {
const gsl::span<const T>& initial_hidden_state,
const ActivationFuncs::Entry& activation_func_f,
const ActivationFuncs::Entry& activation_func_g,
const float clip,
onnxruntime::concurrency::ThreadPool& ttp_);
const float clip);
void Compute(const gsl::span<const T>& inputs,
const gsl::span<const int>& sequence_lengths,
@ -195,8 +192,6 @@ class UniDirectionalGru {
AllocatorPtr allocator_;
const logging::Logger& logger_;
onnxruntime::concurrency::ThreadPool& ttp_;
int seq_length_;
int batch_size_;
int input_size_;
@ -207,9 +202,6 @@ class UniDirectionalGru {
Direction direction_;
bool use_bias_;
bool batch_parallel_;
int hidden_num_threads_ = -1;
IAllocatorUniquePtr<T> outputZRH_ptr_;
gsl::span<T> outputZRH_;
@ -243,17 +235,18 @@ class UniDirectionalGru {
gsl::span<T> inputs_reverse_;
gsl::span<T> outputs_reverse_;
deepcpu::ClipWithBiasFuncPtr clip_with_bias_ptr_ = nullptr;
deepcpu::ClipWithBiasFuncPtr clip_with_bias_ptr_{};
float zr_alpha_ = 0.f, zr_beta_ = 0.f;
float h_alpha_ = 0.f, h_beta_ = 0.f;
float zr_alpha_{};
float zr_beta_{};
float h_alpha_{};
float h_beta_{};
deepcpu::GruResetGateFuncPtr reset_gate_ = nullptr;
deepcpu::ActivationFuncPtr update_gate_ = nullptr;
deepcpu::GruOutputGateFuncPtr output_gate_ = nullptr;
deepcpu::GruResetGateFuncPtr reset_gate_{};
deepcpu::ActivationFuncPtr update_gate_{};
deepcpu::GruOutputGateFuncPtr output_gate_{};
void AllocateBuffers();
void SetNumThreads();
};
} // namespace detail
@ -268,7 +261,6 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const {
const Tensor& X = *context->Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
Status status;
// auto& logger = context->Logger();
auto data_type = X.DataType();
if (data_type == DataTypeImpl::GetType<float>())
@ -393,7 +385,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
bias_1, initial_hidden_1,
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
clip_, ttp_);
clip_);
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
std::unique_ptr<detail::UniDirectionalGru<T>> bw = std::make_unique<detail::UniDirectionalGru<T>>(
@ -402,7 +394,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
bias_2, initial_hidden_2,
activation_funcs_.Entries()[2],
activation_funcs_.Entries()[3],
clip_, ttp_);
clip_);
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2);
} else {
std::unique_ptr<detail::UniDirectionalGru<T>> gru_p = std::make_unique<detail::UniDirectionalGru<T>>(
@ -411,7 +403,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
bias_1, initial_hidden_1,
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
clip_, ttp_);
clip_);
gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
}
@ -441,11 +433,9 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
const gsl::span<const T>& initial_hidden_state,
const ActivationFuncs::Entry& activation_func_f,
const ActivationFuncs::Entry& activation_func_g,
const float clip,
onnxruntime::concurrency::ThreadPool& ttp)
const float clip)
: allocator_(allocator),
logger_(logger),
ttp_(ttp),
seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
@ -454,7 +444,6 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
clip_(clip),
direction_(direction),
use_bias_(!bias.empty()) {
//
clip_with_bias_ptr_ = use_bias_ ? deepcpu::clip_add_bias : deepcpu::clip_ignore_bias;
// setup activation function pointers and alpha/beta values to use with them
@ -467,7 +456,6 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
h_alpha_ = activation_func_g.alpha;
h_beta_ = activation_func_g.beta;
SetNumThreads();
AllocateBuffers();
if (use_bias_) {
@ -598,429 +586,209 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
span_T_const_iter batched_bias_Rh_local_end = batched_bias_Rh_.cend();
span_T_const_iter batched_bias_WRh_local_end = batched_bias_WRh_.cend();
if (batch_parallel_) {
int fused_hidden_rows = batch_size_ / hidden_num_threads_;
if (batch_size_ % hidden_num_threads_ != 0)
fused_hidden_rows++;
size_t out_added_offset;
// lambda executed by ThreadPool
auto hidden_gemm_and_activations = [&](const int row) {
//handling boundaries
int local_fused_hidden_rows = fused_hidden_rows;
if ((row + fused_hidden_rows) > batch_size_)
local_fused_hidden_rows = batch_size_ - row;
span_T_const_iter prev_Ht = batched_hidden0_.cbegin(); // Ht-1
span_T_const_iter prev_Ht_end = batched_hidden0_.cend();
span_T_iter cur_h_local = cur_h_.begin();
span_T_iter cur_h_local_end = cur_h_.end();
size_t out_added_offset;
span_T_const_iter prev_Ht = batched_hidden0_.cbegin() + row * hidden_size_; // Ht-1
span_T_const_iter prev_Ht_end = batched_hidden0_.cend();
span_T_iter cur_h_local = cur_h_.begin() + row * hidden_size_;
span_T_iter cur_h_local_end = cur_h_.end();
span_T_iter linear_output_local;
span_T_iter linear_output_local_end;
span_T_const_iter batched_bias_WRz_local;
span_T_const_iter batched_bias_WRr_local;
span_T_const_iter batched_bias_WRh_local;
span_T_const_iter batched_bias_Wh_local;
span_T_const_iter batched_bias_Rh_local;
span_T_const_iter batched_bias_WRz_local;
span_T_const_iter batched_bias_WRr_local;
span_T_const_iter batched_bias_WRh_local;
span_T_const_iter batched_bias_Wh_local;
span_T_const_iter batched_bias_Rh_local;
if (use_bias_) {
batched_bias_WRz_local = batched_bias_WRz_.cbegin();
batched_bias_WRr_local = batched_bias_WRr_.cbegin();
if (use_bias_) {
batched_bias_WRz_local = batched_bias_WRz_.cbegin() + row * hidden_size_;
batched_bias_WRr_local = batched_bias_WRr_.cbegin() + row * hidden_size_;
if (linear_before_reset_) {
batched_bias_Wh_local = batched_bias_Wh_.cbegin() + row * hidden_size_;
batched_bias_Rh_local = batched_bias_Rh_.cbegin() + row * hidden_size_;
linear_output_local = linear_output_.begin() + row * hidden_size_;
linear_output_local_end = linear_output_.end();
} else {
batched_bias_WRh_local = batched_bias_WRh_.cbegin() + row * hidden_size_;
}
}
for (int step = 0; step < max_sequence_length; step++) {
const std::string row_str = " [row=" + std::to_string(row) + ",seqno=" + std::to_string(step) + "]";
DumpMatrix("Ht-1" + row_str, &*prev_Ht, local_fused_hidden_rows, hidden_size_);
out_added_offset = (step * batch_size_ + row) * hidden_size_x3;
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
ComputeGemm(local_fused_hidden_rows, hidden_size_x2, hidden_size_, alpha,
prev_Ht, prev_Ht_end,
hidden_size_,
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
hidden_size_, beta,
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3);
DumpMatrix("Xt*(W[zr]^T) + Ht-1 * R[zr]" + row_str,
outputZRH_.data() + out_added_offset, local_fused_hidden_rows, hidden_size_x2, 0, hidden_size_x3);
if (linear_before_reset_) {
// copy Rbh to linear output
gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), local_fused_hidden_rows * hidden_size_),
linear_output_.subspan(linear_output_local - linear_output_.begin(), linear_output_local_end - linear_output_local));
// compute Ht-1 * (Rh^T) + Rbh
ComputeGemm(local_fused_hidden_rows, hidden_size_, hidden_size_, alpha,
prev_Ht, prev_Ht_end, // Ht-1
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
linear_output_local, linear_output_.end(), // pre: Rbh, post:output
hidden_size_);
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + row_str, &*linear_output_local, batch_size_, hidden_size_);
}
// 1st Set Of Activations
for (int r = 0; r < local_fused_hidden_rows; r++) {
const T* p_bias_r = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRr_local + r * hidden_size_,
batched_bias_WRr_local_end, hidden_size_)
: nullptr;
// initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T).
T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_);
// add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_);
if (linear_before_reset_) {
// p_linear_output = Ht-1 * (Rh^T) + Rbh
T* p_linear_output = SafeRawPointer<T>(linear_output_local + r * hidden_size_,
linear_output_local_end, hidden_size_);
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
// calculate rt in-place [p_rt = f(p_rt)]
// calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h
reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
} else {
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
// calculate rt in-place [p_rt = f(p_rt)]
// calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h
reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
}
}
std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1";
DumpMatrix(label + row_str, &*cur_h_local, local_fused_hidden_rows, hidden_size_);
if (linear_before_reset_) {
// input contains rt (.) (Ht-1*(Rh^T) + Rbh)
auto input = cur_h_local;
// out_H currently contains Xt*(W[zrh]^T).
auto out_H = outputZRH_.begin() + out_added_offset;
for (int r = 0; r < local_fused_hidden_rows; r++) {
// skip over the inputs with Z and R weights
out_H += hidden_size_x2;
for (int h = 0; h < hidden_size_; ++h) {
*out_H += *input;
++out_H;
++input;
}
}
} else {
label += " * Rh^T";
ComputeGemm(local_fused_hidden_rows, hidden_size_, hidden_size_, alpha,
cur_h_local, cur_h_local_end,
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(),
hidden_size_, beta,
outputZRH_.begin() + out_added_offset + hidden_size_x2, outputZRH_.end(),
hidden_size_x3);
}
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + row_str,
outputZRH_.data() + out_added_offset, local_fused_hidden_rows, hidden_size_,
hidden_size_x2, hidden_size_x3);
// 2nd Set of Activations
span_T_iter output;
span_T_iter output_end;
if (output_sequence) {
output = outputs.begin() + step * output_step_length + row * hidden_size_;
output_end = outputs.end();
} else {
output = final_hidden_state.begin() + row * hidden_size_;
output_end = final_hidden_state.end();
}
for (int r = 0; r < local_fused_hidden_rows; r++) {
if (step >= min_sequence_length && step >= sequence_lengths[row + r]) {
if (output_sequence) {
auto fill_output = output + r * hidden_size_;
std::fill_n(fill_output, hidden_size_, T{});
}
continue;
}
const T* p_bias_z = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRz_local, batched_bias_WRz_local_end,
hidden_size_)
: nullptr;
// initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt:
T* p_zt = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_);
// using p_zt, add bias and clip in-place
clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_);
// calculate zt in-place. p_zt = f(p_zt)
update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_);
DumpMatrix("zt[" + std::to_string(r) + "]" + row_str, p_zt, 1, hidden_size_);
const T* p_bias_h = nullptr;
if (use_bias_) {
if (linear_before_reset_) {
// Wbh
p_bias_h = SafeRawConstPointer<T>(batched_bias_Wh_local + r * hidden_size_,
batched_bias_Wh_local_end, hidden_size_);
} else {
// Wbh + Wrh
p_bias_h = SafeRawConstPointer<T>(batched_bias_WRh_local + r * hidden_size_,
batched_bias_WRh_local_end, hidden_size_);
}
}
// setup p_ht with input to calculate ht
// p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false
// = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true
T* p_ht = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_);
// add Wbh [and Wrh] and clip
clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht = input to g() for calculating ht
DumpMatrix("ht input [" + std::to_string(r) + "]" + row_str, p_ht, 1, hidden_size_);
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
T* p_Ht = SafeRawPointer<T>(output + r * hidden_size_, output_end, hidden_size_);
// calculate ht = g(p_ht) and write in-place to p_ht
// calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht
output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_);
}
DumpMatrix("output" + row_str, &*output, 1, hidden_size_);
prev_Ht = output;
prev_Ht_end = output_end;
}
};
ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, ttp_, logger_);
} else {
size_t out_added_offset;
span_T_const_iter prev_Ht = batched_hidden0_.cbegin(); // Ht-1
span_T_const_iter prev_Ht_end = batched_hidden0_.cend();
span_T_iter cur_h_local = cur_h_.begin();
span_T_iter cur_h_local_end = cur_h_.end();
span_T_const_iter batched_bias_WRz_local;
span_T_const_iter batched_bias_WRr_local;
span_T_const_iter batched_bias_WRh_local;
span_T_const_iter batched_bias_Wh_local;
span_T_const_iter batched_bias_Rh_local;
if (use_bias_) {
batched_bias_WRz_local = batched_bias_WRz_.cbegin();
batched_bias_WRr_local = batched_bias_WRr_.cbegin();
if (linear_before_reset_) {
batched_bias_Wh_local = batched_bias_Wh_.cbegin();
batched_bias_Rh_local = batched_bias_Rh_.cbegin();
} else {
batched_bias_WRh_local = batched_bias_WRh_.cbegin();
}
if (linear_before_reset_) {
batched_bias_Wh_local = batched_bias_Wh_.cbegin();
batched_bias_Rh_local = batched_bias_Rh_.cbegin();
} else {
batched_bias_WRh_local = batched_bias_WRh_.cbegin();
}
}
// for each item in sequence run all calculations
for (int step = 0; step < max_sequence_length; step++) {
const std::string seqno_str = " [seqno=" + std::to_string(step) + "]";
// for each item in sequence run all calculations
for (int step = 0; step < max_sequence_length; step++) {
const std::string seqno_str = " [seqno=" + std::to_string(step) + "]";
DumpMatrix("Ht-1" + seqno_str, &*prev_Ht, batch_size_, hidden_size_);
DumpMatrix("Ht-1" + seqno_str, &*prev_Ht, batch_size_, hidden_size_);
out_added_offset = (step * batch_size_) * hidden_size_x3;
out_added_offset = (step * batch_size_) * hidden_size_x3;
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
// Ht-1 * R[zr] + Xt*(W[zr]^T)
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha,
prev_Ht, prev_Ht_end,
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
// Ht-1 * R[zr] + Xt*(W[zr]^T)
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha,
prev_Ht, prev_Ht_end,
hidden_size_,
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
hidden_size_, beta,
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3);
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
if (linear_before_reset_) {
// copy Rbh to linear output
gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), batched_bias_Rh_local_end - batched_bias_Rh_local), linear_output_);
// compute Ht-1 * (Rh^T) + Rbh
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
prev_Ht, prev_Ht_end, // Ht-1
hidden_size_,
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3);
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
hidden_size_);
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
}
// 1st Set Of Activations
for (int r = 0; r < batch_size_; r++) {
const T* p_bias_r = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRr_local + r * hidden_size_,
batched_bias_WRr_local_end, hidden_size_)
: nullptr;
// initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T).
T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_);
// add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_);
if (linear_before_reset_) {
// copy Rbh to linear output
gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), batched_bias_Rh_local_end - batched_bias_Rh_local), linear_output_);
// p_linear_output = Ht-1 * (Rh^T) + Rbh
T* p_linear_output = SafeRawPointer<T>(linear_output_, r * hidden_size_, hidden_size_);
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
// compute Ht-1 * (Rh^T) + Rbh
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
prev_Ht, prev_Ht_end, // Ht-1
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
hidden_size_);
// calculate rt in-place [p_rt = f(p_rt)]
// calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h
reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
} else {
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
// calculate rt in-place [p_rt = f(p_rt)]
// calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h
reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
}
}
std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1";
DumpMatrix(label + seqno_str, &*cur_h_local, batch_size_, hidden_size_);
if (linear_before_reset_) {
// input contains rt (.) (Ht-1*(Rh^T) + Rbh)
auto input = cur_h_local;
// out_H currently contains Xt*(W[zrh]^T).
auto out_H = outputZRH_.begin() + out_added_offset;
for (int r = 0; r < batch_size_; r++) {
// skip over the inputs with Z and R weights
out_H += hidden_size_x2;
for (int h = 0; h < hidden_size_; ++h) {
*out_H += *input;
++out_H;
++input;
}
}
} else {
label += " * Rh^T";
// out_H currently contains Xt*(Wh^T).
auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2;
// Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
cur_h_local, cur_h_local_end, // rt (.) Ht-1
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
out_H, outputZRH_.end(),
hidden_size_x3);
}
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3);
//2nd Set of Activations
span_T_iter output;
span_T_iter output_end;
if (output_sequence) {
output = outputs.begin() + step * output_step_length;
output_end = outputs.end();
} else {
output = final_hidden_state.begin();
output_end = final_hidden_state.end();
}
for (int r = 0; r < batch_size_; r++) {
if (step >= min_sequence_length && step >= sequence_lengths[r]) {
if (output_sequence) {
auto fill_output = output + r * hidden_size_;
std::fill_n(fill_output, hidden_size_, T{});
}
continue;
}
// 1st Set Of Activations
for (int r = 0; r < batch_size_; r++) {
const T* p_bias_r = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRr_local + r * hidden_size_,
batched_bias_WRr_local_end, hidden_size_)
: nullptr;
const T* p_bias_z = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRz_local,
batched_bias_WRz_local_end, hidden_size_)
: nullptr;
// initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T).
T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_);
// initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt:
T* p_zt = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_);
// add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_);
// using p_zt, add bias and clip in-place
clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_);
// calculate zt in-place. p_zt = f(p_zt)
update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_);
DumpMatrix("zt[" + std::to_string(r) + "]" + seqno_str, p_zt, 1, hidden_size_);
const T* p_bias_h = nullptr;
if (use_bias_) {
if (linear_before_reset_) {
// p_linear_output = Ht-1 * (Rh^T) + Rbh
T* p_linear_output = SafeRawPointer<T>(linear_output_, r * hidden_size_, hidden_size_);
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
// calculate rt in-place [p_rt = f(p_rt)]
// calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h
reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
// Wbh
p_bias_h = SafeRawConstPointer<T>(batched_bias_Wh_local + r * hidden_size_,
batched_bias_Wh_local_end, hidden_size_);
} else {
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
// calculate rt in-place [p_rt = f(p_rt)]
// calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h
reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
// Wbh + Wrh
p_bias_h = SafeRawConstPointer<T>(batched_bias_WRh_local + r * hidden_size_,
batched_bias_WRh_local_end, hidden_size_);
}
}
std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1";
DumpMatrix(label + seqno_str, &*cur_h_local, batch_size_, hidden_size_);
// setup p_ht with input to calculate ht
// p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false
// = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true
T* p_ht = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_);
if (linear_before_reset_) {
// input contains rt (.) (Ht-1*(Rh^T) + Rbh)
auto input = cur_h_local;
// out_H currently contains Xt*(W[zrh]^T).
auto out_H = outputZRH_.begin() + out_added_offset;
// add Wbh [and Wrh] and clip
clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht == input to g() for calculating ht
for (int r = 0; r < batch_size_; r++) {
// skip over the inputs with Z and R weights
out_H += hidden_size_x2;
for (int h = 0; h < hidden_size_; ++h) {
*out_H += *input;
++out_H;
++input;
}
}
} else {
label += " * Rh^T";
DumpMatrix("ht input [" + std::to_string(r) + "]" + seqno_str, p_ht, 1, hidden_size_);
// out_H currently contains Xt*(Wh^T).
auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2;
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
T* p_Ht = SafeRawPointer<T>(output + r * hidden_size_, output_end, hidden_size_);
// Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
cur_h_local, cur_h_local_end, // rt (.) Ht-1
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
out_H, outputZRH_.end(),
hidden_size_x3);
}
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3);
//2nd Set of Activations
span_T_iter output;
span_T_iter output_end;
if (output_sequence) {
output = outputs.begin() + step * output_step_length;
output_end = outputs.end();
} else {
output = final_hidden_state.begin();
output_end = final_hidden_state.end();
}
for (int r = 0; r < batch_size_; r++) {
if (step >= min_sequence_length && step >= sequence_lengths[r]) {
if (output_sequence) {
auto fill_output = output + r * hidden_size_;
std::fill_n(fill_output, hidden_size_, T{});
}
continue;
}
const T* p_bias_z = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRz_local,
batched_bias_WRz_local_end, hidden_size_)
: nullptr;
// initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt:
T* p_zt = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_);
// using p_zt, add bias and clip in-place
clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_);
// calculate zt in-place. p_zt = f(p_zt)
update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_);
DumpMatrix("zt[" + std::to_string(r) + "]" + seqno_str, p_zt, 1, hidden_size_);
const T* p_bias_h = nullptr;
if (use_bias_) {
if (linear_before_reset_) {
// Wbh
p_bias_h = SafeRawConstPointer<T>(batched_bias_Wh_local + r * hidden_size_,
batched_bias_Wh_local_end, hidden_size_);
} else {
// Wbh + Wrh
p_bias_h = SafeRawConstPointer<T>(batched_bias_WRh_local + r * hidden_size_,
batched_bias_WRh_local_end, hidden_size_);
}
}
// setup p_ht with input to calculate ht
// p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false
// = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true
T* p_ht = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_);
// add Wbh [and Wrh] and clip
clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht == input to g() for calculating ht
DumpMatrix("ht input [" + std::to_string(r) + "]" + seqno_str, p_ht, 1, hidden_size_);
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
T* p_Ht = SafeRawPointer<T>(output + r * hidden_size_, output_end, hidden_size_);
// calculate ht = g(p_ht) and write in-place to p_ht
// calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht
output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); // calculate ht and Ht
}
DumpMatrix("output" + seqno_str, &*output, batch_size_, hidden_size_);
prev_Ht = output;
prev_Ht_end = output_end;
// calculate ht = g(p_ht) and write in-place to p_ht
// calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht
output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); // calculate ht and Ht
}
DumpMatrix("output" + seqno_str, &*output, batch_size_, hidden_size_);
prev_Ht = output;
prev_Ht_end = output_end;
}
// copy last output to final_hidden_state
@ -1072,29 +840,5 @@ void UniDirectionalGru<T>::AllocateBuffers() {
}
}
template <typename T>
void UniDirectionalGru<T>::SetNumThreads() {
int threads = std::thread::hardware_concurrency() - 1;
if (threads < 1)
threads = 1;
hidden_num_threads_ = threads;
batch_parallel_ = false;
// for readability of the below logic
const auto num_rows = batch_size_;
const auto num_columns = hidden_size_;
// parallelize by partitioning the batch rows
if (num_rows > 4 ||
(num_rows >= 2 && num_columns <= 256) ||
(num_rows >= 3 && num_columns <= 512)) {
batch_parallel_ = true;
VLOGS(logger_, 1) << "Hidden Threads : " << hidden_num_threads_;
}
ORT_ENFORCE(hidden_num_threads_ >= 1);
}
} // namespace detail
} // namespace onnxruntime

View file

@ -60,18 +60,12 @@ class DeepCpuGruOp final : public OpKernel {
rnn::detail::Direction direction_;
int num_directions_;
int hidden_size_ = 0;
int hidden_size_ {};
float clip_;
int linear_before_reset_ = 0;
int linear_before_reset_ {};
rnn::detail::ActivationFuncs activation_funcs_;
// Threadpool for operator. If concurrent Compute calls are possible, it will be shared
// across them. mutable due to this.
// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
// cost on every call.
mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_GRU", (int)std::thread::hardware_concurrency()};
template <typename T>
Status ComputeImpl(OpKernelContext& context) const;
};