mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Tuning GRU performance for batch size >= 2 (#644)
GRU with batch size > 1 was implemented on the assumption that Lotus uses single-threaded Eigen GEMM. That assumption no longer holds: MLAS and MKLML support multi-threading, and we no longer rely on Eigen GEMM. This PR implements the batch size > 1 case the same way as batch size == 1. With this change, we get about a 2x performance gain for GRU. Please refer to the performance test below (times in ms):
Batch_size | Seq_length | input_size | hidden_size | Old | New
8 | 30 | 512 | 512 | 19.16 | 10.47
16 | 30 | 512 | 512 | 28.13 | 15.15
32 | 30 | 512 | 512 | 36.97 | 26.89
8 | 30 | 1024 | 1024 | 142.853 | 55.67
16 | 30 | 1024 | 1024 | 184.397 | 72.32
32 | 30 | 1024 | 1024 | 236.364 | 112.78
This commit is contained in:
parent
80d69515ed
commit
d0f846aad5
2 changed files with 186 additions and 448 deletions
|
|
@ -161,8 +161,6 @@ using namespace rnn::detail;
|
|||
// internal helper code
|
||||
namespace detail {
|
||||
|
||||
/// The class represents DeepCPU implementation of a gated recurrent unit (GRU) operator.
|
||||
/// For details, refer to http://aka.ms/dl-optimization/.
|
||||
template <typename T>
|
||||
class UniDirectionalGru {
|
||||
public:
|
||||
|
|
@ -178,8 +176,7 @@ class UniDirectionalGru {
|
|||
const gsl::span<const T>& initial_hidden_state,
|
||||
const ActivationFuncs::Entry& activation_func_f,
|
||||
const ActivationFuncs::Entry& activation_func_g,
|
||||
const float clip,
|
||||
onnxruntime::concurrency::ThreadPool& ttp_);
|
||||
const float clip);
|
||||
|
||||
void Compute(const gsl::span<const T>& inputs,
|
||||
const gsl::span<const int>& sequence_lengths,
|
||||
|
|
@ -195,8 +192,6 @@ class UniDirectionalGru {
|
|||
AllocatorPtr allocator_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
onnxruntime::concurrency::ThreadPool& ttp_;
|
||||
|
||||
int seq_length_;
|
||||
int batch_size_;
|
||||
int input_size_;
|
||||
|
|
@ -207,9 +202,6 @@ class UniDirectionalGru {
|
|||
|
||||
Direction direction_;
|
||||
bool use_bias_;
|
||||
bool batch_parallel_;
|
||||
|
||||
int hidden_num_threads_ = -1;
|
||||
|
||||
IAllocatorUniquePtr<T> outputZRH_ptr_;
|
||||
gsl::span<T> outputZRH_;
|
||||
|
|
@ -243,17 +235,18 @@ class UniDirectionalGru {
|
|||
gsl::span<T> inputs_reverse_;
|
||||
gsl::span<T> outputs_reverse_;
|
||||
|
||||
deepcpu::ClipWithBiasFuncPtr clip_with_bias_ptr_ = nullptr;
|
||||
deepcpu::ClipWithBiasFuncPtr clip_with_bias_ptr_{};
|
||||
|
||||
float zr_alpha_ = 0.f, zr_beta_ = 0.f;
|
||||
float h_alpha_ = 0.f, h_beta_ = 0.f;
|
||||
float zr_alpha_{};
|
||||
float zr_beta_{};
|
||||
float h_alpha_{};
|
||||
float h_beta_{};
|
||||
|
||||
deepcpu::GruResetGateFuncPtr reset_gate_ = nullptr;
|
||||
deepcpu::ActivationFuncPtr update_gate_ = nullptr;
|
||||
deepcpu::GruOutputGateFuncPtr output_gate_ = nullptr;
|
||||
deepcpu::GruResetGateFuncPtr reset_gate_{};
|
||||
deepcpu::ActivationFuncPtr update_gate_{};
|
||||
deepcpu::GruOutputGateFuncPtr output_gate_{};
|
||||
|
||||
void AllocateBuffers();
|
||||
void SetNumThreads();
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
|
|
@ -268,7 +261,6 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const {
|
|||
const Tensor& X = *context->Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
|
||||
|
||||
Status status;
|
||||
// auto& logger = context->Logger();
|
||||
|
||||
auto data_type = X.DataType();
|
||||
if (data_type == DataTypeImpl::GetType<float>())
|
||||
|
|
@ -393,7 +385,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
|||
bias_1, initial_hidden_1,
|
||||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
clip_, ttp_);
|
||||
clip_);
|
||||
fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
|
||||
|
||||
std::unique_ptr<detail::UniDirectionalGru<T>> bw = std::make_unique<detail::UniDirectionalGru<T>>(
|
||||
|
|
@ -402,7 +394,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
|||
bias_2, initial_hidden_2,
|
||||
activation_funcs_.Entries()[2],
|
||||
activation_funcs_.Entries()[3],
|
||||
clip_, ttp_);
|
||||
clip_);
|
||||
bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2);
|
||||
} else {
|
||||
std::unique_ptr<detail::UniDirectionalGru<T>> gru_p = std::make_unique<detail::UniDirectionalGru<T>>(
|
||||
|
|
@ -411,7 +403,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
|
|||
bias_1, initial_hidden_1,
|
||||
activation_funcs_.Entries()[0],
|
||||
activation_funcs_.Entries()[1],
|
||||
clip_, ttp_);
|
||||
clip_);
|
||||
|
||||
gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1);
|
||||
}
|
||||
|
|
@ -441,11 +433,9 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
|
|||
const gsl::span<const T>& initial_hidden_state,
|
||||
const ActivationFuncs::Entry& activation_func_f,
|
||||
const ActivationFuncs::Entry& activation_func_g,
|
||||
const float clip,
|
||||
onnxruntime::concurrency::ThreadPool& ttp)
|
||||
const float clip)
|
||||
: allocator_(allocator),
|
||||
logger_(logger),
|
||||
ttp_(ttp),
|
||||
seq_length_(seq_length),
|
||||
batch_size_(batch_size),
|
||||
input_size_(input_size),
|
||||
|
|
@ -454,7 +444,6 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
|
|||
clip_(clip),
|
||||
direction_(direction),
|
||||
use_bias_(!bias.empty()) {
|
||||
//
|
||||
clip_with_bias_ptr_ = use_bias_ ? deepcpu::clip_add_bias : deepcpu::clip_ignore_bias;
|
||||
|
||||
// setup activation function pointers and alpha/beta values to use with them
|
||||
|
|
@ -467,7 +456,6 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
|
|||
h_alpha_ = activation_func_g.alpha;
|
||||
h_beta_ = activation_func_g.beta;
|
||||
|
||||
SetNumThreads();
|
||||
AllocateBuffers();
|
||||
|
||||
if (use_bias_) {
|
||||
|
|
@ -598,429 +586,209 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
|
|||
span_T_const_iter batched_bias_Rh_local_end = batched_bias_Rh_.cend();
|
||||
span_T_const_iter batched_bias_WRh_local_end = batched_bias_WRh_.cend();
|
||||
|
||||
if (batch_parallel_) {
|
||||
int fused_hidden_rows = batch_size_ / hidden_num_threads_;
|
||||
if (batch_size_ % hidden_num_threads_ != 0)
|
||||
fused_hidden_rows++;
|
||||
size_t out_added_offset;
|
||||
|
||||
// lambda executed by ThreadPool
|
||||
auto hidden_gemm_and_activations = [&](const int row) {
|
||||
//handling boundaries
|
||||
int local_fused_hidden_rows = fused_hidden_rows;
|
||||
if ((row + fused_hidden_rows) > batch_size_)
|
||||
local_fused_hidden_rows = batch_size_ - row;
|
||||
span_T_const_iter prev_Ht = batched_hidden0_.cbegin(); // Ht-1
|
||||
span_T_const_iter prev_Ht_end = batched_hidden0_.cend();
|
||||
span_T_iter cur_h_local = cur_h_.begin();
|
||||
span_T_iter cur_h_local_end = cur_h_.end();
|
||||
|
||||
size_t out_added_offset;
|
||||
span_T_const_iter prev_Ht = batched_hidden0_.cbegin() + row * hidden_size_; // Ht-1
|
||||
span_T_const_iter prev_Ht_end = batched_hidden0_.cend();
|
||||
span_T_iter cur_h_local = cur_h_.begin() + row * hidden_size_;
|
||||
span_T_iter cur_h_local_end = cur_h_.end();
|
||||
span_T_iter linear_output_local;
|
||||
span_T_iter linear_output_local_end;
|
||||
span_T_const_iter batched_bias_WRz_local;
|
||||
span_T_const_iter batched_bias_WRr_local;
|
||||
span_T_const_iter batched_bias_WRh_local;
|
||||
span_T_const_iter batched_bias_Wh_local;
|
||||
span_T_const_iter batched_bias_Rh_local;
|
||||
|
||||
span_T_const_iter batched_bias_WRz_local;
|
||||
span_T_const_iter batched_bias_WRr_local;
|
||||
span_T_const_iter batched_bias_WRh_local;
|
||||
span_T_const_iter batched_bias_Wh_local;
|
||||
span_T_const_iter batched_bias_Rh_local;
|
||||
if (use_bias_) {
|
||||
batched_bias_WRz_local = batched_bias_WRz_.cbegin();
|
||||
batched_bias_WRr_local = batched_bias_WRr_.cbegin();
|
||||
|
||||
if (use_bias_) {
|
||||
batched_bias_WRz_local = batched_bias_WRz_.cbegin() + row * hidden_size_;
|
||||
batched_bias_WRr_local = batched_bias_WRr_.cbegin() + row * hidden_size_;
|
||||
|
||||
if (linear_before_reset_) {
|
||||
batched_bias_Wh_local = batched_bias_Wh_.cbegin() + row * hidden_size_;
|
||||
batched_bias_Rh_local = batched_bias_Rh_.cbegin() + row * hidden_size_;
|
||||
linear_output_local = linear_output_.begin() + row * hidden_size_;
|
||||
linear_output_local_end = linear_output_.end();
|
||||
} else {
|
||||
batched_bias_WRh_local = batched_bias_WRh_.cbegin() + row * hidden_size_;
|
||||
}
|
||||
}
|
||||
|
||||
for (int step = 0; step < max_sequence_length; step++) {
|
||||
const std::string row_str = " [row=" + std::to_string(row) + ",seqno=" + std::to_string(step) + "]";
|
||||
|
||||
DumpMatrix("Ht-1" + row_str, &*prev_Ht, local_fused_hidden_rows, hidden_size_);
|
||||
|
||||
out_added_offset = (step * batch_size_ + row) * hidden_size_x3;
|
||||
|
||||
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
|
||||
ComputeGemm(local_fused_hidden_rows, hidden_size_x2, hidden_size_, alpha,
|
||||
prev_Ht, prev_Ht_end,
|
||||
hidden_size_,
|
||||
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
|
||||
hidden_size_, beta,
|
||||
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
|
||||
DumpMatrix("Xt*(W[zr]^T) + Ht-1 * R[zr]" + row_str,
|
||||
outputZRH_.data() + out_added_offset, local_fused_hidden_rows, hidden_size_x2, 0, hidden_size_x3);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// copy Rbh to linear output
|
||||
gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), local_fused_hidden_rows * hidden_size_),
|
||||
linear_output_.subspan(linear_output_local - linear_output_.begin(), linear_output_local_end - linear_output_local));
|
||||
|
||||
// compute Ht-1 * (Rh^T) + Rbh
|
||||
ComputeGemm(local_fused_hidden_rows, hidden_size_, hidden_size_, alpha,
|
||||
prev_Ht, prev_Ht_end, // Ht-1
|
||||
hidden_size_,
|
||||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
linear_output_local, linear_output_.end(), // pre: Rbh, post:output
|
||||
hidden_size_);
|
||||
|
||||
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + row_str, &*linear_output_local, batch_size_, hidden_size_);
|
||||
}
|
||||
|
||||
// 1st Set Of Activations
|
||||
for (int r = 0; r < local_fused_hidden_rows; r++) {
|
||||
const T* p_bias_r = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRr_local + r * hidden_size_,
|
||||
batched_bias_WRr_local_end, hidden_size_)
|
||||
: nullptr;
|
||||
|
||||
// initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T).
|
||||
T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_);
|
||||
|
||||
// add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
|
||||
clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// p_linear_output = Ht-1 * (Rh^T) + Rbh
|
||||
T* p_linear_output = SafeRawPointer<T>(linear_output_local + r * hidden_size_,
|
||||
linear_output_local_end, hidden_size_);
|
||||
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
|
||||
|
||||
// calculate rt in-place [p_rt = f(p_rt)]
|
||||
// calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h
|
||||
reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
|
||||
|
||||
} else {
|
||||
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
|
||||
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
|
||||
|
||||
// calculate rt in-place [p_rt = f(p_rt)]
|
||||
// calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h
|
||||
reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
|
||||
}
|
||||
}
|
||||
|
||||
std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1";
|
||||
DumpMatrix(label + row_str, &*cur_h_local, local_fused_hidden_rows, hidden_size_);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// input contains rt (.) (Ht-1*(Rh^T) + Rbh)
|
||||
auto input = cur_h_local;
|
||||
// out_H currently contains Xt*(W[zrh]^T).
|
||||
auto out_H = outputZRH_.begin() + out_added_offset;
|
||||
|
||||
for (int r = 0; r < local_fused_hidden_rows; r++) {
|
||||
// skip over the inputs with Z and R weights
|
||||
out_H += hidden_size_x2;
|
||||
for (int h = 0; h < hidden_size_; ++h) {
|
||||
*out_H += *input;
|
||||
++out_H;
|
||||
++input;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
label += " * Rh^T";
|
||||
ComputeGemm(local_fused_hidden_rows, hidden_size_, hidden_size_, alpha,
|
||||
cur_h_local, cur_h_local_end,
|
||||
hidden_size_,
|
||||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(),
|
||||
hidden_size_, beta,
|
||||
outputZRH_.begin() + out_added_offset + hidden_size_x2, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
}
|
||||
|
||||
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + row_str,
|
||||
outputZRH_.data() + out_added_offset, local_fused_hidden_rows, hidden_size_,
|
||||
hidden_size_x2, hidden_size_x3);
|
||||
|
||||
// 2nd Set of Activations
|
||||
span_T_iter output;
|
||||
span_T_iter output_end;
|
||||
if (output_sequence) {
|
||||
output = outputs.begin() + step * output_step_length + row * hidden_size_;
|
||||
output_end = outputs.end();
|
||||
|
||||
} else {
|
||||
output = final_hidden_state.begin() + row * hidden_size_;
|
||||
output_end = final_hidden_state.end();
|
||||
}
|
||||
|
||||
for (int r = 0; r < local_fused_hidden_rows; r++) {
|
||||
if (step >= min_sequence_length && step >= sequence_lengths[row + r]) {
|
||||
if (output_sequence) {
|
||||
auto fill_output = output + r * hidden_size_;
|
||||
std::fill_n(fill_output, hidden_size_, T{});
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
const T* p_bias_z = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRz_local, batched_bias_WRz_local_end,
|
||||
hidden_size_)
|
||||
: nullptr;
|
||||
|
||||
// initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt:
|
||||
T* p_zt = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_);
|
||||
|
||||
// using p_zt, add bias and clip in-place
|
||||
clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_);
|
||||
|
||||
// calculate zt in-place. p_zt = f(p_zt)
|
||||
update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_);
|
||||
|
||||
DumpMatrix("zt[" + std::to_string(r) + "]" + row_str, p_zt, 1, hidden_size_);
|
||||
|
||||
const T* p_bias_h = nullptr;
|
||||
if (use_bias_) {
|
||||
if (linear_before_reset_) {
|
||||
// Wbh
|
||||
p_bias_h = SafeRawConstPointer<T>(batched_bias_Wh_local + r * hidden_size_,
|
||||
batched_bias_Wh_local_end, hidden_size_);
|
||||
|
||||
} else {
|
||||
// Wbh + Wrh
|
||||
p_bias_h = SafeRawConstPointer<T>(batched_bias_WRh_local + r * hidden_size_,
|
||||
batched_bias_WRh_local_end, hidden_size_);
|
||||
}
|
||||
}
|
||||
|
||||
// setup p_ht with input to calculate ht
|
||||
// p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false
|
||||
// = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true
|
||||
T* p_ht = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_);
|
||||
|
||||
// add Wbh [and Wrh] and clip
|
||||
clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht = input to g() for calculating ht
|
||||
|
||||
DumpMatrix("ht input [" + std::to_string(r) + "]" + row_str, p_ht, 1, hidden_size_);
|
||||
|
||||
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
|
||||
T* p_Ht = SafeRawPointer<T>(output + r * hidden_size_, output_end, hidden_size_);
|
||||
|
||||
// calculate ht = g(p_ht) and write in-place to p_ht
|
||||
// calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht
|
||||
output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_);
|
||||
}
|
||||
|
||||
DumpMatrix("output" + row_str, &*output, 1, hidden_size_);
|
||||
|
||||
prev_Ht = output;
|
||||
prev_Ht_end = output_end;
|
||||
}
|
||||
};
|
||||
|
||||
ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, ttp_, logger_);
|
||||
} else {
|
||||
size_t out_added_offset;
|
||||
|
||||
span_T_const_iter prev_Ht = batched_hidden0_.cbegin(); // Ht-1
|
||||
span_T_const_iter prev_Ht_end = batched_hidden0_.cend();
|
||||
span_T_iter cur_h_local = cur_h_.begin();
|
||||
span_T_iter cur_h_local_end = cur_h_.end();
|
||||
|
||||
span_T_const_iter batched_bias_WRz_local;
|
||||
span_T_const_iter batched_bias_WRr_local;
|
||||
span_T_const_iter batched_bias_WRh_local;
|
||||
span_T_const_iter batched_bias_Wh_local;
|
||||
span_T_const_iter batched_bias_Rh_local;
|
||||
|
||||
if (use_bias_) {
|
||||
batched_bias_WRz_local = batched_bias_WRz_.cbegin();
|
||||
batched_bias_WRr_local = batched_bias_WRr_.cbegin();
|
||||
|
||||
if (linear_before_reset_) {
|
||||
batched_bias_Wh_local = batched_bias_Wh_.cbegin();
|
||||
batched_bias_Rh_local = batched_bias_Rh_.cbegin();
|
||||
} else {
|
||||
batched_bias_WRh_local = batched_bias_WRh_.cbegin();
|
||||
}
|
||||
if (linear_before_reset_) {
|
||||
batched_bias_Wh_local = batched_bias_Wh_.cbegin();
|
||||
batched_bias_Rh_local = batched_bias_Rh_.cbegin();
|
||||
} else {
|
||||
batched_bias_WRh_local = batched_bias_WRh_.cbegin();
|
||||
}
|
||||
}
|
||||
|
||||
// for each item in sequence run all calculations
|
||||
for (int step = 0; step < max_sequence_length; step++) {
|
||||
const std::string seqno_str = " [seqno=" + std::to_string(step) + "]";
|
||||
// for each item in sequence run all calculations
|
||||
for (int step = 0; step < max_sequence_length; step++) {
|
||||
const std::string seqno_str = " [seqno=" + std::to_string(step) + "]";
|
||||
|
||||
DumpMatrix("Ht-1" + seqno_str, &*prev_Ht, batch_size_, hidden_size_);
|
||||
DumpMatrix("Ht-1" + seqno_str, &*prev_Ht, batch_size_, hidden_size_);
|
||||
|
||||
out_added_offset = (step * batch_size_) * hidden_size_x3;
|
||||
out_added_offset = (step * batch_size_) * hidden_size_x3;
|
||||
|
||||
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
|
||||
// Ht-1 * R[zr] + Xt*(W[zr]^T)
|
||||
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha,
|
||||
prev_Ht, prev_Ht_end,
|
||||
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
|
||||
// Ht-1 * R[zr] + Xt*(W[zr]^T)
|
||||
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha,
|
||||
prev_Ht, prev_Ht_end,
|
||||
hidden_size_,
|
||||
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
|
||||
hidden_size_, beta,
|
||||
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
|
||||
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
|
||||
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// copy Rbh to linear output
|
||||
gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), batched_bias_Rh_local_end - batched_bias_Rh_local), linear_output_);
|
||||
|
||||
// compute Ht-1 * (Rh^T) + Rbh
|
||||
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
|
||||
prev_Ht, prev_Ht_end, // Ht-1
|
||||
hidden_size_,
|
||||
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
|
||||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
|
||||
hidden_size_);
|
||||
|
||||
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
|
||||
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
|
||||
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
|
||||
}
|
||||
|
||||
// 1st Set Of Activations
|
||||
for (int r = 0; r < batch_size_; r++) {
|
||||
const T* p_bias_r = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRr_local + r * hidden_size_,
|
||||
batched_bias_WRr_local_end, hidden_size_)
|
||||
: nullptr;
|
||||
|
||||
// initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T).
|
||||
T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_);
|
||||
|
||||
// add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
|
||||
clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// copy Rbh to linear output
|
||||
gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), batched_bias_Rh_local_end - batched_bias_Rh_local), linear_output_);
|
||||
// p_linear_output = Ht-1 * (Rh^T) + Rbh
|
||||
T* p_linear_output = SafeRawPointer<T>(linear_output_, r * hidden_size_, hidden_size_);
|
||||
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
|
||||
|
||||
// compute Ht-1 * (Rh^T) + Rbh
|
||||
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
|
||||
prev_Ht, prev_Ht_end, // Ht-1
|
||||
hidden_size_,
|
||||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
|
||||
hidden_size_);
|
||||
// calculate rt in-place [p_rt = f(p_rt)]
|
||||
// calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h
|
||||
reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
|
||||
|
||||
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
|
||||
} else {
|
||||
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
|
||||
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
|
||||
|
||||
// calculate rt in-place [p_rt = f(p_rt)]
|
||||
// calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h
|
||||
reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
|
||||
}
|
||||
}
|
||||
|
||||
std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1";
|
||||
DumpMatrix(label + seqno_str, &*cur_h_local, batch_size_, hidden_size_);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// input contains rt (.) (Ht-1*(Rh^T) + Rbh)
|
||||
auto input = cur_h_local;
|
||||
// out_H currently contains Xt*(W[zrh]^T).
|
||||
auto out_H = outputZRH_.begin() + out_added_offset;
|
||||
|
||||
for (int r = 0; r < batch_size_; r++) {
|
||||
// skip over the inputs with Z and R weights
|
||||
out_H += hidden_size_x2;
|
||||
for (int h = 0; h < hidden_size_; ++h) {
|
||||
*out_H += *input;
|
||||
++out_H;
|
||||
++input;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
label += " * Rh^T";
|
||||
|
||||
// out_H currently contains Xt*(Wh^T).
|
||||
auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2;
|
||||
|
||||
// Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh
|
||||
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
|
||||
cur_h_local, cur_h_local_end, // rt (.) Ht-1
|
||||
hidden_size_,
|
||||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
out_H, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
}
|
||||
|
||||
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
|
||||
batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3);
|
||||
|
||||
//2nd Set of Activations
|
||||
span_T_iter output;
|
||||
span_T_iter output_end;
|
||||
if (output_sequence) {
|
||||
output = outputs.begin() + step * output_step_length;
|
||||
output_end = outputs.end();
|
||||
|
||||
} else {
|
||||
output = final_hidden_state.begin();
|
||||
output_end = final_hidden_state.end();
|
||||
}
|
||||
|
||||
for (int r = 0; r < batch_size_; r++) {
|
||||
if (step >= min_sequence_length && step >= sequence_lengths[r]) {
|
||||
if (output_sequence) {
|
||||
auto fill_output = output + r * hidden_size_;
|
||||
std::fill_n(fill_output, hidden_size_, T{});
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// 1st Set Of Activations
|
||||
for (int r = 0; r < batch_size_; r++) {
|
||||
const T* p_bias_r = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRr_local + r * hidden_size_,
|
||||
batched_bias_WRr_local_end, hidden_size_)
|
||||
: nullptr;
|
||||
const T* p_bias_z = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRz_local,
|
||||
batched_bias_WRz_local_end, hidden_size_)
|
||||
: nullptr;
|
||||
|
||||
// initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T).
|
||||
T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_);
|
||||
// initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt:
|
||||
T* p_zt = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_);
|
||||
|
||||
// add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
|
||||
clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_);
|
||||
// using p_zt, add bias and clip in-place
|
||||
clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_);
|
||||
|
||||
// calculate zt in-place. p_zt = f(p_zt)
|
||||
update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_);
|
||||
|
||||
DumpMatrix("zt[" + std::to_string(r) + "]" + seqno_str, p_zt, 1, hidden_size_);
|
||||
|
||||
const T* p_bias_h = nullptr;
|
||||
if (use_bias_) {
|
||||
if (linear_before_reset_) {
|
||||
// p_linear_output = Ht-1 * (Rh^T) + Rbh
|
||||
T* p_linear_output = SafeRawPointer<T>(linear_output_, r * hidden_size_, hidden_size_);
|
||||
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
|
||||
|
||||
// calculate rt in-place [p_rt = f(p_rt)]
|
||||
// calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h
|
||||
reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
|
||||
// Wbh
|
||||
p_bias_h = SafeRawConstPointer<T>(batched_bias_Wh_local + r * hidden_size_,
|
||||
batched_bias_Wh_local_end, hidden_size_);
|
||||
|
||||
} else {
|
||||
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
|
||||
T* p_cur_h = SafeRawPointer<T>(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_);
|
||||
|
||||
// calculate rt in-place [p_rt = f(p_rt)]
|
||||
// calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h
|
||||
reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_);
|
||||
// Wbh + Wrh
|
||||
p_bias_h = SafeRawConstPointer<T>(batched_bias_WRh_local + r * hidden_size_,
|
||||
batched_bias_WRh_local_end, hidden_size_);
|
||||
}
|
||||
}
|
||||
|
||||
std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1";
|
||||
DumpMatrix(label + seqno_str, &*cur_h_local, batch_size_, hidden_size_);
|
||||
// setup p_ht with input to calculate ht
|
||||
// p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false
|
||||
// = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true
|
||||
T* p_ht = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_);
|
||||
|
||||
if (linear_before_reset_) {
|
||||
// input contains rt (.) (Ht-1*(Rh^T) + Rbh)
|
||||
auto input = cur_h_local;
|
||||
// out_H currently contains Xt*(W[zrh]^T).
|
||||
auto out_H = outputZRH_.begin() + out_added_offset;
|
||||
// add Wbh [and Wrh] and clip
|
||||
clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht == input to g() for calculating ht
|
||||
|
||||
for (int r = 0; r < batch_size_; r++) {
|
||||
// skip over the inputs with Z and R weights
|
||||
out_H += hidden_size_x2;
|
||||
for (int h = 0; h < hidden_size_; ++h) {
|
||||
*out_H += *input;
|
||||
++out_H;
|
||||
++input;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
label += " * Rh^T";
|
||||
DumpMatrix("ht input [" + std::to_string(r) + "]" + seqno_str, p_ht, 1, hidden_size_);
|
||||
|
||||
// out_H currently contains Xt*(Wh^T).
|
||||
auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2;
|
||||
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
|
||||
T* p_Ht = SafeRawPointer<T>(output + r * hidden_size_, output_end, hidden_size_);
|
||||
|
||||
// Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh
|
||||
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
|
||||
cur_h_local, cur_h_local_end, // rt (.) Ht-1
|
||||
hidden_size_,
|
||||
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
|
||||
hidden_size_, beta,
|
||||
out_H, outputZRH_.end(),
|
||||
hidden_size_x3);
|
||||
}
|
||||
|
||||
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
|
||||
batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3);
|
||||
|
||||
//2nd Set of Activations
|
||||
span_T_iter output;
|
||||
span_T_iter output_end;
|
||||
if (output_sequence) {
|
||||
output = outputs.begin() + step * output_step_length;
|
||||
output_end = outputs.end();
|
||||
|
||||
} else {
|
||||
output = final_hidden_state.begin();
|
||||
output_end = final_hidden_state.end();
|
||||
}
|
||||
|
||||
for (int r = 0; r < batch_size_; r++) {
|
||||
if (step >= min_sequence_length && step >= sequence_lengths[r]) {
|
||||
if (output_sequence) {
|
||||
auto fill_output = output + r * hidden_size_;
|
||||
std::fill_n(fill_output, hidden_size_, T{});
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
const T* p_bias_z = use_bias_ ? SafeRawConstPointer<T>(batched_bias_WRz_local,
|
||||
batched_bias_WRz_local_end, hidden_size_)
|
||||
: nullptr;
|
||||
|
||||
// initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt:
|
||||
T* p_zt = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_);
|
||||
|
||||
// using p_zt, add bias and clip in-place
|
||||
clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_);
|
||||
|
||||
// calculate zt in-place. p_zt = f(p_zt)
|
||||
update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_);
|
||||
|
||||
DumpMatrix("zt[" + std::to_string(r) + "]" + seqno_str, p_zt, 1, hidden_size_);
|
||||
|
||||
const T* p_bias_h = nullptr;
|
||||
if (use_bias_) {
|
||||
if (linear_before_reset_) {
|
||||
// Wbh
|
||||
p_bias_h = SafeRawConstPointer<T>(batched_bias_Wh_local + r * hidden_size_,
|
||||
batched_bias_Wh_local_end, hidden_size_);
|
||||
|
||||
} else {
|
||||
// Wbh + Wrh
|
||||
p_bias_h = SafeRawConstPointer<T>(batched_bias_WRh_local + r * hidden_size_,
|
||||
batched_bias_WRh_local_end, hidden_size_);
|
||||
}
|
||||
}
|
||||
|
||||
// setup p_ht with input to calculate ht
|
||||
// p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false
|
||||
// = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true
|
||||
T* p_ht = SafeRawPointer<T>(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_);
|
||||
|
||||
// add Wbh [and Wrh] and clip
|
||||
clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht == input to g() for calculating ht
|
||||
|
||||
DumpMatrix("ht input [" + std::to_string(r) + "]" + seqno_str, p_ht, 1, hidden_size_);
|
||||
|
||||
const T* p_prev_Ht = SafeRawConstPointer<T>(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_);
|
||||
T* p_Ht = SafeRawPointer<T>(output + r * hidden_size_, output_end, hidden_size_);
|
||||
|
||||
// calculate ht = g(p_ht) and write in-place to p_ht
|
||||
// calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht
|
||||
output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); // calculate ht and Ht
|
||||
}
|
||||
|
||||
DumpMatrix("output" + seqno_str, &*output, batch_size_, hidden_size_);
|
||||
|
||||
prev_Ht = output;
|
||||
prev_Ht_end = output_end;
|
||||
// calculate ht = g(p_ht) and write in-place to p_ht
|
||||
// calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht
|
||||
output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); // calculate ht and Ht
|
||||
}
|
||||
|
||||
DumpMatrix("output" + seqno_str, &*output, batch_size_, hidden_size_);
|
||||
|
||||
prev_Ht = output;
|
||||
prev_Ht_end = output_end;
|
||||
}
|
||||
|
||||
// copy last output to final_hidden_state
|
||||
|
|
@ -1072,29 +840,5 @@ void UniDirectionalGru<T>::AllocateBuffers() {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void UniDirectionalGru<T>::SetNumThreads() {
|
||||
int threads = std::thread::hardware_concurrency() - 1;
|
||||
|
||||
if (threads < 1)
|
||||
threads = 1;
|
||||
|
||||
hidden_num_threads_ = threads;
|
||||
batch_parallel_ = false;
|
||||
|
||||
// for readability of the below logic
|
||||
const auto num_rows = batch_size_;
|
||||
const auto num_columns = hidden_size_;
|
||||
|
||||
// parallelize by partitioning the batch rows
|
||||
if (num_rows > 4 ||
|
||||
(num_rows >= 2 && num_columns <= 256) ||
|
||||
(num_rows >= 3 && num_columns <= 512)) {
|
||||
batch_parallel_ = true;
|
||||
VLOGS(logger_, 1) << "Hidden Threads : " << hidden_num_threads_;
|
||||
}
|
||||
|
||||
ORT_ENFORCE(hidden_num_threads_ >= 1);
|
||||
}
|
||||
} // namespace detail
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -60,18 +60,12 @@ class DeepCpuGruOp final : public OpKernel {
|
|||
rnn::detail::Direction direction_;
|
||||
int num_directions_;
|
||||
|
||||
int hidden_size_ = 0;
|
||||
int hidden_size_ {};
|
||||
float clip_;
|
||||
int linear_before_reset_ = 0;
|
||||
int linear_before_reset_ {};
|
||||
|
||||
rnn::detail::ActivationFuncs activation_funcs_;
|
||||
|
||||
// Threadpool for operator. If concurrent Compute calls are possible, it will be shared
|
||||
// across them. mutable due to this.
|
||||
// The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
|
||||
// cost on every call.
|
||||
mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_GRU", (int)std::thread::hardware_concurrency()};
|
||||
|
||||
template <typename T>
|
||||
Status ComputeImpl(OpKernelContext& context) const;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue