From d0f846aad5168ff3de50bb2e06c80b99c3da41ea Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Tue, 23 Apr 2019 14:50:11 -0700 Subject: [PATCH] Tuning GRU performance for batch size >= 2 (#644) GRU with batch size > 1 is implemented on the assumption that Lotus uses single-thread Eigen Gemm. The assumption doesn't hold anymore. MLAS and MKLML support multi-threading. We don't rely on Eigen Gemm anymore. This PR implements batch size > 1 the same way as batch size == 1. With this change, we have about 2x performance gain for GRU. Please refer to the performance test below: (ms) Batch_size | Seq_length | input_size | hidden_size | Old | New 8 | 30 | 512 | 512 | 19.16 | 10.47 16 | 30 | 512 | 512 | 28.13 | 15.15 32 | 30 | 512 | 512 | 36.97 | 26.89 8 | 30 | 1024 | 1024 | 142.853 | 55.67 16 | 30 | 1024 | 1024 | 184.397 | 72.32 32 | 30 | 1024 | 1024 | 236.364 | 112.78 --- .../core/providers/cpu/rnn/deep_cpu_gru.cc | 624 ++++++------------ .../core/providers/cpu/rnn/deep_cpu_gru.h | 10 +- 2 files changed, 186 insertions(+), 448 deletions(-) diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index ce54029307..4f6878d276 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -161,8 +161,6 @@ using namespace rnn::detail; // internal helper code namespace detail { -/// The class represents DeepCPU implementation of a gated recurrent unit (GRU) operator. -/// For details, refer to http://aka.ms/dl-optimization/. 
template class UniDirectionalGru { public: @@ -178,8 +176,7 @@ class UniDirectionalGru { const gsl::span& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g, - const float clip, - onnxruntime::concurrency::ThreadPool& ttp_); + const float clip); void Compute(const gsl::span& inputs, const gsl::span& sequence_lengths, @@ -195,8 +192,6 @@ class UniDirectionalGru { AllocatorPtr allocator_; const logging::Logger& logger_; - onnxruntime::concurrency::ThreadPool& ttp_; - int seq_length_; int batch_size_; int input_size_; @@ -207,9 +202,6 @@ class UniDirectionalGru { Direction direction_; bool use_bias_; - bool batch_parallel_; - - int hidden_num_threads_ = -1; IAllocatorUniquePtr outputZRH_ptr_; gsl::span outputZRH_; @@ -243,17 +235,18 @@ class UniDirectionalGru { gsl::span inputs_reverse_; gsl::span outputs_reverse_; - deepcpu::ClipWithBiasFuncPtr clip_with_bias_ptr_ = nullptr; + deepcpu::ClipWithBiasFuncPtr clip_with_bias_ptr_{}; - float zr_alpha_ = 0.f, zr_beta_ = 0.f; - float h_alpha_ = 0.f, h_beta_ = 0.f; + float zr_alpha_{}; + float zr_beta_{}; + float h_alpha_{}; + float h_beta_{}; - deepcpu::GruResetGateFuncPtr reset_gate_ = nullptr; - deepcpu::ActivationFuncPtr update_gate_ = nullptr; - deepcpu::GruOutputGateFuncPtr output_gate_ = nullptr; + deepcpu::GruResetGateFuncPtr reset_gate_{}; + deepcpu::ActivationFuncPtr update_gate_{}; + deepcpu::GruOutputGateFuncPtr output_gate_{}; void AllocateBuffers(); - void SetNumThreads(); }; } // namespace detail @@ -268,7 +261,6 @@ Status DeepCpuGruOp::Compute(OpKernelContext* context) const { const Tensor& X = *context->Input(0); // inputs. 
[seq_length, batch_size, input_size] Status status; - // auto& logger = context->Logger(); auto data_type = X.DataType(); if (data_type == DataTypeImpl::GetType()) @@ -393,7 +385,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { bias_1, initial_hidden_1, activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], - clip_, ttp_); + clip_); fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1); std::unique_ptr> bw = std::make_unique>( @@ -402,7 +394,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { bias_2, initial_hidden_2, activation_funcs_.Entries()[2], activation_funcs_.Entries()[3], - clip_, ttp_); + clip_); bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2, hidden_output_2); } else { std::unique_ptr> gru_p = std::make_unique>( @@ -411,7 +403,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { bias_1, initial_hidden_1, activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], - clip_, ttp_); + clip_); gru_p->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1); } @@ -441,11 +433,9 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, const gsl::span& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g, - const float clip, - onnxruntime::concurrency::ThreadPool& ttp) + const float clip) : allocator_(allocator), logger_(logger), - ttp_(ttp), seq_length_(seq_length), batch_size_(batch_size), input_size_(input_size), @@ -454,7 +444,6 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, clip_(clip), direction_(direction), use_bias_(!bias.empty()) { - // clip_with_bias_ptr_ = use_bias_ ? 
deepcpu::clip_add_bias : deepcpu::clip_ignore_bias; // setup activation function pointers and alpha/beta values to use with them @@ -467,7 +456,6 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, h_alpha_ = activation_func_g.alpha; h_beta_ = activation_func_g.beta; - SetNumThreads(); AllocateBuffers(); if (use_bias_) { @@ -598,429 +586,209 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, span_T_const_iter batched_bias_Rh_local_end = batched_bias_Rh_.cend(); span_T_const_iter batched_bias_WRh_local_end = batched_bias_WRh_.cend(); - if (batch_parallel_) { - int fused_hidden_rows = batch_size_ / hidden_num_threads_; - if (batch_size_ % hidden_num_threads_ != 0) - fused_hidden_rows++; + size_t out_added_offset; - // lambda executed by ThreadPool - auto hidden_gemm_and_activations = [&](const int row) { - //handling boundaries - int local_fused_hidden_rows = fused_hidden_rows; - if ((row + fused_hidden_rows) > batch_size_) - local_fused_hidden_rows = batch_size_ - row; + span_T_const_iter prev_Ht = batched_hidden0_.cbegin(); // Ht-1 + span_T_const_iter prev_Ht_end = batched_hidden0_.cend(); + span_T_iter cur_h_local = cur_h_.begin(); + span_T_iter cur_h_local_end = cur_h_.end(); - size_t out_added_offset; - span_T_const_iter prev_Ht = batched_hidden0_.cbegin() + row * hidden_size_; // Ht-1 - span_T_const_iter prev_Ht_end = batched_hidden0_.cend(); - span_T_iter cur_h_local = cur_h_.begin() + row * hidden_size_; - span_T_iter cur_h_local_end = cur_h_.end(); - span_T_iter linear_output_local; - span_T_iter linear_output_local_end; + span_T_const_iter batched_bias_WRz_local; + span_T_const_iter batched_bias_WRr_local; + span_T_const_iter batched_bias_WRh_local; + span_T_const_iter batched_bias_Wh_local; + span_T_const_iter batched_bias_Rh_local; - span_T_const_iter batched_bias_WRz_local; - span_T_const_iter batched_bias_WRr_local; - span_T_const_iter batched_bias_WRh_local; - span_T_const_iter batched_bias_Wh_local; - span_T_const_iter 
batched_bias_Rh_local; + if (use_bias_) { + batched_bias_WRz_local = batched_bias_WRz_.cbegin(); + batched_bias_WRr_local = batched_bias_WRr_.cbegin(); - if (use_bias_) { - batched_bias_WRz_local = batched_bias_WRz_.cbegin() + row * hidden_size_; - batched_bias_WRr_local = batched_bias_WRr_.cbegin() + row * hidden_size_; - - if (linear_before_reset_) { - batched_bias_Wh_local = batched_bias_Wh_.cbegin() + row * hidden_size_; - batched_bias_Rh_local = batched_bias_Rh_.cbegin() + row * hidden_size_; - linear_output_local = linear_output_.begin() + row * hidden_size_; - linear_output_local_end = linear_output_.end(); - } else { - batched_bias_WRh_local = batched_bias_WRh_.cbegin() + row * hidden_size_; - } - } - - for (int step = 0; step < max_sequence_length; step++) { - const std::string row_str = " [row=" + std::to_string(row) + ",seqno=" + std::to_string(step) + "]"; - - DumpMatrix("Ht-1" + row_str, &*prev_Ht, local_fused_hidden_rows, hidden_size_); - - out_added_offset = (step * batch_size_ + row) * hidden_size_x3; - - // calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_ - ComputeGemm(local_fused_hidden_rows, hidden_size_x2, hidden_size_, alpha, - prev_Ht, prev_Ht_end, - hidden_size_, - recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(), - hidden_size_, beta, - outputZRH_.begin() + out_added_offset, outputZRH_.end(), - hidden_size_x3); - - DumpMatrix("Xt*(W[zr]^T) + Ht-1 * R[zr]" + row_str, - outputZRH_.data() + out_added_offset, local_fused_hidden_rows, hidden_size_x2, 0, hidden_size_x3); - - if (linear_before_reset_) { - // copy Rbh to linear output - gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), local_fused_hidden_rows * hidden_size_), - linear_output_.subspan(linear_output_local - linear_output_.begin(), linear_output_local_end - linear_output_local)); - - // compute Ht-1 * (Rh^T) + Rbh - ComputeGemm(local_fused_hidden_rows, hidden_size_, hidden_size_, alpha, - prev_Ht, prev_Ht_end, // 
Ht-1 - hidden_size_, - recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T - hidden_size_, beta, - linear_output_local, linear_output_.end(), // pre: Rbh, post:output - hidden_size_); - - DumpMatrix("Ht-1 * (Rh^T) + Rbh " + row_str, &*linear_output_local, batch_size_, hidden_size_); - } - - // 1st Set Of Activations - for (int r = 0; r < local_fused_hidden_rows; r++) { - const T* p_bias_r = use_bias_ ? SafeRawConstPointer(batched_bias_WRr_local + r * hidden_size_, - batched_bias_WRr_local_end, hidden_size_) - : nullptr; - - // initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T). - T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_); - - // add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr - clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_); - - if (linear_before_reset_) { - // p_linear_output = Ht-1 * (Rh^T) + Rbh - T* p_linear_output = SafeRawPointer(linear_output_local + r * hidden_size_, - linear_output_local_end, hidden_size_); - T* p_cur_h = SafeRawPointer(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_); - - // calculate rt in-place [p_rt = f(p_rt)] - // calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h - reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_); - - } else { - const T* p_prev_Ht = SafeRawConstPointer(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_); - T* p_cur_h = SafeRawPointer(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_); - - // calculate rt in-place [p_rt = f(p_rt)] - // calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h - reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_); - } - } - - std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) 
Ht-1"; - DumpMatrix(label + row_str, &*cur_h_local, local_fused_hidden_rows, hidden_size_); - - if (linear_before_reset_) { - // input contains rt (.) (Ht-1*(Rh^T) + Rbh) - auto input = cur_h_local; - // out_H currently contains Xt*(W[zrh]^T). - auto out_H = outputZRH_.begin() + out_added_offset; - - for (int r = 0; r < local_fused_hidden_rows; r++) { - // skip over the inputs with Z and R weights - out_H += hidden_size_x2; - for (int h = 0; h < hidden_size_; ++h) { - *out_H += *input; - ++out_H; - ++input; - } - } - } else { - label += " * Rh^T"; - ComputeGemm(local_fused_hidden_rows, hidden_size_, hidden_size_, alpha, - cur_h_local, cur_h_local_end, - hidden_size_, - recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), - hidden_size_, beta, - outputZRH_.begin() + out_added_offset + hidden_size_x2, outputZRH_.end(), - hidden_size_x3); - } - - DumpMatrix("Xt*(Wh^T) + (" + label + ")" + row_str, - outputZRH_.data() + out_added_offset, local_fused_hidden_rows, hidden_size_, - hidden_size_x2, hidden_size_x3); - - // 2nd Set of Activations - span_T_iter output; - span_T_iter output_end; - if (output_sequence) { - output = outputs.begin() + step * output_step_length + row * hidden_size_; - output_end = outputs.end(); - - } else { - output = final_hidden_state.begin() + row * hidden_size_; - output_end = final_hidden_state.end(); - } - - for (int r = 0; r < local_fused_hidden_rows; r++) { - if (step >= min_sequence_length && step >= sequence_lengths[row + r]) { - if (output_sequence) { - auto fill_output = output + r * hidden_size_; - std::fill_n(fill_output, hidden_size_, T{}); - } - - continue; - } - - const T* p_bias_z = use_bias_ ? 
SafeRawConstPointer(batched_bias_WRz_local, batched_bias_WRz_local_end, - hidden_size_) - : nullptr; - - // initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt: - T* p_zt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_); - - // using p_zt, add bias and clip in-place - clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_); - - // calculate zt in-place. p_zt = f(p_zt) - update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_); - - DumpMatrix("zt[" + std::to_string(r) + "]" + row_str, p_zt, 1, hidden_size_); - - const T* p_bias_h = nullptr; - if (use_bias_) { - if (linear_before_reset_) { - // Wbh - p_bias_h = SafeRawConstPointer(batched_bias_Wh_local + r * hidden_size_, - batched_bias_Wh_local_end, hidden_size_); - - } else { - // Wbh + Wrh - p_bias_h = SafeRawConstPointer(batched_bias_WRh_local + r * hidden_size_, - batched_bias_WRh_local_end, hidden_size_); - } - } - - // setup p_ht with input to calculate ht - // p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false - // = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true - T* p_ht = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_); - - // add Wbh [and Wrh] and clip - clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht = input to g() for calculating ht - - DumpMatrix("ht input [" + std::to_string(r) + "]" + row_str, p_ht, 1, hidden_size_); - - const T* p_prev_Ht = SafeRawConstPointer(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_); - T* p_Ht = SafeRawPointer(output + r * hidden_size_, output_end, hidden_size_); - - // calculate ht = g(p_ht) and write in-place to p_ht - // calculate Ht = (1 - zt) (.) ht + zt (.) 
Ht-1 and write to p_Ht - output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); - } - - DumpMatrix("output" + row_str, &*output, 1, hidden_size_); - - prev_Ht = output; - prev_Ht_end = output_end; - } - }; - - ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, ttp_, logger_); - } else { - size_t out_added_offset; - - span_T_const_iter prev_Ht = batched_hidden0_.cbegin(); // Ht-1 - span_T_const_iter prev_Ht_end = batched_hidden0_.cend(); - span_T_iter cur_h_local = cur_h_.begin(); - span_T_iter cur_h_local_end = cur_h_.end(); - - span_T_const_iter batched_bias_WRz_local; - span_T_const_iter batched_bias_WRr_local; - span_T_const_iter batched_bias_WRh_local; - span_T_const_iter batched_bias_Wh_local; - span_T_const_iter batched_bias_Rh_local; - - if (use_bias_) { - batched_bias_WRz_local = batched_bias_WRz_.cbegin(); - batched_bias_WRr_local = batched_bias_WRr_.cbegin(); - - if (linear_before_reset_) { - batched_bias_Wh_local = batched_bias_Wh_.cbegin(); - batched_bias_Rh_local = batched_bias_Rh_.cbegin(); - } else { - batched_bias_WRh_local = batched_bias_WRh_.cbegin(); - } + if (linear_before_reset_) { + batched_bias_Wh_local = batched_bias_Wh_.cbegin(); + batched_bias_Rh_local = batched_bias_Rh_.cbegin(); + } else { + batched_bias_WRh_local = batched_bias_WRh_.cbegin(); } + } - // for each item in sequence run all calculations - for (int step = 0; step < max_sequence_length; step++) { - const std::string seqno_str = " [seqno=" + std::to_string(step) + "]"; + // for each item in sequence run all calculations + for (int step = 0; step < max_sequence_length; step++) { + const std::string seqno_str = " [seqno=" + std::to_string(step) + "]"; - DumpMatrix("Ht-1" + seqno_str, &*prev_Ht, batch_size_, hidden_size_); + DumpMatrix("Ht-1" + seqno_str, &*prev_Ht, batch_size_, hidden_size_); - out_added_offset = (step * batch_size_) * hidden_size_x3; + out_added_offset = (step * batch_size_) * 
hidden_size_x3; - // calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_ - // Ht-1 * R[zr] + Xt*(W[zr]^T) - ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha, - prev_Ht, prev_Ht_end, + // calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_ + // Ht-1 * R[zr] + Xt*(W[zr]^T) + ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha, + prev_Ht, prev_Ht_end, + hidden_size_, + recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(), + hidden_size_, beta, + outputZRH_.begin() + out_added_offset, outputZRH_.end(), + hidden_size_x3); + + DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str, + outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3); + + if (linear_before_reset_) { + // copy Rbh to linear output + gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), batched_bias_Rh_local_end - batched_bias_Rh_local), linear_output_); + + // compute Ht-1 * (Rh^T) + Rbh + ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, + prev_Ht, prev_Ht_end, // Ht-1 hidden_size_, - recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(), + recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T hidden_size_, beta, - outputZRH_.begin() + out_added_offset, outputZRH_.end(), - hidden_size_x3); + linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output + hidden_size_); - DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str, - outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3); + DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_); + } + + // 1st Set Of Activations + for (int r = 0; r < batch_size_; r++) { + const T* p_bias_r = use_bias_ ? SafeRawConstPointer(batched_bias_WRr_local + r * hidden_size_, + batched_bias_WRr_local_end, hidden_size_) + : nullptr; + + // initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T). 
+ T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_); + + // add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr + clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_); if (linear_before_reset_) { - // copy Rbh to linear output - gsl::copy(batched_bias_Rh_.subspan(batched_bias_Rh_local - batched_bias_Rh_.begin(), batched_bias_Rh_local_end - batched_bias_Rh_local), linear_output_); + // p_linear_output = Ht-1 * (Rh^T) + Rbh + T* p_linear_output = SafeRawPointer(linear_output_, r * hidden_size_, hidden_size_); + T* p_cur_h = SafeRawPointer(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_); - // compute Ht-1 * (Rh^T) + Rbh - ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, - prev_Ht, prev_Ht_end, // Ht-1 - hidden_size_, - recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T - hidden_size_, beta, - linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output - hidden_size_); + // calculate rt in-place [p_rt = f(p_rt)] + // calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h + reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_); - DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_); + } else { + const T* p_prev_Ht = SafeRawConstPointer(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_); + T* p_cur_h = SafeRawPointer(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_); + + // calculate rt in-place [p_rt = f(p_rt)] + // calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h + reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_); + } + } + + std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) Ht-1"; + DumpMatrix(label + seqno_str, &*cur_h_local, batch_size_, hidden_size_); + + if (linear_before_reset_) { + // input contains rt (.) 
(Ht-1*(Rh^T) + Rbh) + auto input = cur_h_local; + // out_H currently contains Xt*(W[zrh]^T). + auto out_H = outputZRH_.begin() + out_added_offset; + + for (int r = 0; r < batch_size_; r++) { + // skip over the inputs with Z and R weights + out_H += hidden_size_x2; + for (int h = 0; h < hidden_size_; ++h) { + *out_H += *input; + ++out_H; + ++input; + } + } + } else { + label += " * Rh^T"; + + // out_H currently contains Xt*(Wh^T). + auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2; + + // Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh + ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, + cur_h_local, cur_h_local_end, // rt (.) Ht-1 + hidden_size_, + recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T + hidden_size_, beta, + out_H, outputZRH_.end(), + hidden_size_x3); + } + + DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset, + batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3); + + //2nd Set of Activations + span_T_iter output; + span_T_iter output_end; + if (output_sequence) { + output = outputs.begin() + step * output_step_length; + output_end = outputs.end(); + + } else { + output = final_hidden_state.begin(); + output_end = final_hidden_state.end(); + } + + for (int r = 0; r < batch_size_; r++) { + if (step >= min_sequence_length && step >= sequence_lengths[r]) { + if (output_sequence) { + auto fill_output = output + r * hidden_size_; + std::fill_n(fill_output, hidden_size_, T{}); + } + + continue; } - // 1st Set Of Activations - for (int r = 0; r < batch_size_; r++) { - const T* p_bias_r = use_bias_ ? SafeRawConstPointer(batched_bias_WRr_local + r * hidden_size_, - batched_bias_WRr_local_end, hidden_size_) - : nullptr; + const T* p_bias_z = use_bias_ ? SafeRawConstPointer(batched_bias_WRz_local, + batched_bias_WRz_local_end, hidden_size_) + : nullptr; - // initialize p_rt with input to calculate rt. outputZRH_ has Xt*(Wr^T) + Ht-1*(Rr^T). 
- T* p_rt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_, hidden_size_); + // initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt: + T* p_zt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_); - // add the bias and clip. post: p_rt == Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr - clip_with_bias_ptr_(clip_, p_bias_r, p_rt, hidden_size_); + // using p_zt, add bias and clip in-place + clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_); + // calculate zt in-place. p_zt = f(p_zt) + update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_); + + DumpMatrix("zt[" + std::to_string(r) + "]" + seqno_str, p_zt, 1, hidden_size_); + + const T* p_bias_h = nullptr; + if (use_bias_) { if (linear_before_reset_) { - // p_linear_output = Ht-1 * (Rh^T) + Rbh - T* p_linear_output = SafeRawPointer(linear_output_, r * hidden_size_, hidden_size_); - T* p_cur_h = SafeRawPointer(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_); - - // calculate rt in-place [p_rt = f(p_rt)] - // calculate rt (.) (Ht-1 * (Rh^T) + Rbh) using p_linear_output. write to p_cur_h - reset_gate_(p_linear_output, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_); + // Wbh + p_bias_h = SafeRawConstPointer(batched_bias_Wh_local + r * hidden_size_, + batched_bias_Wh_local_end, hidden_size_); } else { - const T* p_prev_Ht = SafeRawConstPointer(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_); - T* p_cur_h = SafeRawPointer(cur_h_local + r * hidden_size_, cur_h_local_end, hidden_size_); - - // calculate rt in-place [p_rt = f(p_rt)] - // calculate rt (.) Ht-1 using p_prev_Ht, and write to p_cur_h - reset_gate_(p_prev_Ht, p_rt, p_cur_h, hidden_size_, zr_alpha_, zr_beta_); + // Wbh + Wrh + p_bias_h = SafeRawConstPointer(batched_bias_WRh_local + r * hidden_size_, + batched_bias_WRh_local_end, hidden_size_); } } - std::string label = linear_before_reset_ ? "rt (.) (Ht-1 * (Rh^T) + Rbh)" : "rt (.) 
Ht-1"; - DumpMatrix(label + seqno_str, &*cur_h_local, batch_size_, hidden_size_); + // setup p_ht with input to calculate ht + // p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false + // = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true + T* p_ht = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_); - if (linear_before_reset_) { - // input contains rt (.) (Ht-1*(Rh^T) + Rbh) - auto input = cur_h_local; - // out_H currently contains Xt*(W[zrh]^T). - auto out_H = outputZRH_.begin() + out_added_offset; + // add Wbh [and Wrh] and clip + clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht == input to g() for calculating ht - for (int r = 0; r < batch_size_; r++) { - // skip over the inputs with Z and R weights - out_H += hidden_size_x2; - for (int h = 0; h < hidden_size_; ++h) { - *out_H += *input; - ++out_H; - ++input; - } - } - } else { - label += " * Rh^T"; + DumpMatrix("ht input [" + std::to_string(r) + "]" + seqno_str, p_ht, 1, hidden_size_); - // out_H currently contains Xt*(Wh^T). - auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2; + const T* p_prev_Ht = SafeRawConstPointer(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_); + T* p_Ht = SafeRawPointer(output + r * hidden_size_, output_end, hidden_size_); - // Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh - ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, - cur_h_local, cur_h_local_end, // rt (.) 
Ht-1 - hidden_size_, - recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T - hidden_size_, beta, - out_H, outputZRH_.end(), - hidden_size_x3); - } - - DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset, - batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3); - - //2nd Set of Activations - span_T_iter output; - span_T_iter output_end; - if (output_sequence) { - output = outputs.begin() + step * output_step_length; - output_end = outputs.end(); - - } else { - output = final_hidden_state.begin(); - output_end = final_hidden_state.end(); - } - - for (int r = 0; r < batch_size_; r++) { - if (step >= min_sequence_length && step >= sequence_lengths[r]) { - if (output_sequence) { - auto fill_output = output + r * hidden_size_; - std::fill_n(fill_output, hidden_size_, T{}); - } - - continue; - } - - const T* p_bias_z = use_bias_ ? SafeRawConstPointer(batched_bias_WRz_local, - batched_bias_WRz_local_end, hidden_size_) - : nullptr; - - // initialize p_zt with Xt*(Wz^T) + Ht-1*(Rz^T), which is most of the input to calculate zt: - T* p_zt = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3, hidden_size_); - - // using p_zt, add bias and clip in-place - clip_with_bias_ptr_(clip_, p_bias_z, p_zt, hidden_size_); - - // calculate zt in-place. p_zt = f(p_zt) - update_gate_(p_zt, hidden_size_, zr_alpha_, zr_beta_); - - DumpMatrix("zt[" + std::to_string(r) + "]" + seqno_str, p_zt, 1, hidden_size_); - - const T* p_bias_h = nullptr; - if (use_bias_) { - if (linear_before_reset_) { - // Wbh - p_bias_h = SafeRawConstPointer(batched_bias_Wh_local + r * hidden_size_, - batched_bias_Wh_local_end, hidden_size_); - - } else { - // Wbh + Wrh - p_bias_h = SafeRawConstPointer(batched_bias_WRh_local + r * hidden_size_, - batched_bias_WRh_local_end, hidden_size_); - } - } - - // setup p_ht with input to calculate ht - // p_ht = Xt*(Wh^T) + (rt (.) Ht-1 * Rh^T) # linear_before_reset_ == false - // = Xt*(Wh^T) + (rt (.) 
(Ht-1*(Rh^T) + Rbh)) # linear_before_reset_ == true - T* p_ht = SafeRawPointer(outputZRH_, out_added_offset + r * hidden_size_x3 + hidden_size_x2, hidden_size_); - - // add Wbh [and Wrh] and clip - clip_with_bias_ptr_(clip_, p_bias_h, p_ht, hidden_size_); // post: p_ht == input to g() for calculating ht - - DumpMatrix("ht input [" + std::to_string(r) + "]" + seqno_str, p_ht, 1, hidden_size_); - - const T* p_prev_Ht = SafeRawConstPointer(prev_Ht + r * hidden_size_, prev_Ht_end, hidden_size_); - T* p_Ht = SafeRawPointer(output + r * hidden_size_, output_end, hidden_size_); - - // calculate ht = g(p_ht) and write in-place to p_ht - // calculate Ht = (1 - zt) (.) ht + zt (.) Ht-1 and write to p_Ht - output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); // calculate ht and Ht - } - - DumpMatrix("output" + seqno_str, &*output, batch_size_, hidden_size_); - - prev_Ht = output; - prev_Ht_end = output_end; + // calculate ht = g(p_ht) and write in-place to p_ht + // calculate Ht = (1 - zt) (.) ht + zt (.) 
Ht-1 and write to p_Ht + output_gate_(p_ht, p_zt, p_prev_Ht, p_Ht, hidden_size_, h_alpha_, h_beta_); // calculate ht and Ht } + + DumpMatrix("output" + seqno_str, &*output, batch_size_, hidden_size_); + + prev_Ht = output; + prev_Ht_end = output_end; } // copy last output to final_hidden_state @@ -1072,29 +840,5 @@ void UniDirectionalGru::AllocateBuffers() { } } -template -void UniDirectionalGru::SetNumThreads() { - int threads = std::thread::hardware_concurrency() - 1; - - if (threads < 1) - threads = 1; - - hidden_num_threads_ = threads; - batch_parallel_ = false; - - // for readability of the below logic - const auto num_rows = batch_size_; - const auto num_columns = hidden_size_; - - // parallelize by partitioning the batch rows - if (num_rows > 4 || - (num_rows >= 2 && num_columns <= 256) || - (num_rows >= 3 && num_columns <= 512)) { - batch_parallel_ = true; - VLOGS(logger_, 1) << "Hidden Threads : " << hidden_num_threads_; - } - - ORT_ENFORCE(hidden_num_threads_ >= 1); -} } // namespace detail } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 421bf106cf..f4967c9d85 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -60,18 +60,12 @@ class DeepCpuGruOp final : public OpKernel { rnn::detail::Direction direction_; int num_directions_; - int hidden_size_ = 0; + int hidden_size_ {}; float clip_; - int linear_before_reset_ = 0; + int linear_before_reset_ {}; rnn::detail::ActivationFuncs activation_funcs_; - // Threadpool for operator. If concurrent Compute calls are possible, it will be shared - // across them. mutable due to this. - // The alternative would be to create a threadpool in each call to Compute but that would incur thread creation - // cost on every call. 
- mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_GRU", (int)std::thread::hardware_concurrency()}; - template Status ComputeImpl(OpKernelContext& context) const; };