From bbedf2c4c5abb2d52859313301b38726598a137b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 9 Nov 2022 09:59:38 -0800 Subject: [PATCH] Improve cache locality and perf of DeepGru on CPU (#13582) ### Description Introduce Gemm weights pre-pack. ### Motivation and Context A 1-P customer requested a performance improvement for DeepGru which consumes a bulk of CPU in their model. This provides measurable performance improvements. Customer model numbers. gru: mean = 356 us; 1ms = 99.8 prctile; 99th prctile = 665 ms (yuslepukhin/deep_gru_opt) main: mean = 375 us; 1ms = 99.8 prctile; 99th prctile = 695 ms (where yuslepukhin/deep_gru_opt branched off main) 1.13.1: mean = 391 us; 1ms = 99.6 prctile; 99th prctile = 744 ms --- .../core/providers/cpu/rnn/deep_cpu_gru.cc | 386 +++++++++++++++--- .../core/providers/cpu/rnn/deep_cpu_gru.h | 28 +- .../core/providers/cpu/rnn/rnn_helpers.h | 11 + 3 files changed, 356 insertions(+), 69 deletions(-) diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index f9b598b5a1..688a734d58 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -10,11 +10,12 @@ #endif #include "core/providers/cpu/rnn/deep_cpu_gru.h" +#include "core/common/narrow.h" #ifdef _MSC_VER #pragma warning(pop) #endif -//TODO: fix the warnings +// TODO: fix the warnings #if defined(_MSC_VER) && !defined(__clang__) // Chance of arithmetic overflow could be reduced #pragma warning(disable : 26451) @@ -168,13 +169,15 @@ template class UniDirectionalGru { public: UniDirectionalGru(AllocatorPtr allocator, int seq_length, int batch_size, int input_size, int hidden_size, - bool linear_before_reset, Direction direction, const gsl::span& bias, - const gsl::span& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f, + bool linear_before_reset, Direction direction, gsl::span bias, + gsl::span initial_hidden_state, const 
ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g, float clip, onnxruntime::concurrency::ThreadPool* ttp); - void Compute(const gsl::span& inputs, const gsl::span& sequence_lengths, int num_directions, - const gsl::span& input_weights, const gsl::span& recurrent_weights, + void Compute(gsl::span inputs, gsl::span sequence_lengths, int num_directions, + const GemmWeights& input_weights, + const GemmWeights& recurrent_weights_ZR, + const GemmWeights& recurrent_weights_H, gsl::span& outputs, gsl::span& final_hidden_state); ~UniDirectionalGru() = default; @@ -249,6 +252,165 @@ class UniDirectionalGru { #define DumpMatrix(...) ((void)0) #endif +bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& alloc) { + const auto& shape = weights.Shape(); + if (shape.NumDimensions() != 3) { + return false; + } + + // weights: [num_directions, 3*hidden_size, input_size] + // recurrence weights: [num_directions, 3*hidden_size, hidden_size] + const auto num_directions = shape[0]; + if (num_directions != num_directions_) { + return false; + } + + const size_t N = static_cast(shape[1]); + const size_t K = static_cast(shape[2]); + + const size_t packed_weights_size = MlasGemmPackBSize(N, K); + if (packed_weights_size == 0) { + return false; + } + + const size_t buffer_size = SafeInt(packed_weights_size) * num_directions; + auto* packed_weights_data = alloc->Alloc(buffer_size); + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we do not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. 
+ memset(packed_weights_data, 0, buffer_size); + + pre_packed_input_weights_.buffer_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); + pre_packed_input_weights_.buffer_size_ = buffer_size; + pre_packed_input_weights_.shape_ = shape; + pre_packed_input_weights_.weights_size_ = packed_weights_size; + + const size_t N_x_K = N * K; + const auto* weights_data = weights.Data(); + for (int64_t dir = 0; dir < num_directions; ++dir) { + MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + weights_data += N_x_K; + packed_weights_data = static_cast(packed_weights_data) + packed_weights_size; + } + + return true; +} + +bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& alloc) { + const auto& shape = weights.Shape(); + if (shape.NumDimensions() != 3) { + return false; + } + + // recurrence weights: [num_directions, 3*hidden_size, hidden_size] + const auto num_directions = shape[0]; + if (num_directions != num_directions_) { + return false; + } + + const auto N = shape[1]; + const auto K = shape[2]; + if (N != SafeInt(K) * 3 || K != hidden_size_) { + return false; + } + + const auto hidden_size_x_2 = N - hidden_size_; + + // We are making two packed buffers, one for ZR weights and another for H weights. 
+ const size_t ZR_packed_size = MlasGemmPackBSize(narrow(hidden_size_x_2), narrow(K)); + if (ZR_packed_size == 0) { + return false; + } + + const size_t H_packed_size = MlasGemmPackBSize(narrow(hidden_size_), narrow(K)); + if (H_packed_size == 0) { + return false; + } + const size_t buffer_size_ZR = SafeInt(ZR_packed_size) * num_directions; + const size_t buffer_size_H = SafeInt(H_packed_size) * num_directions; + + auto* buffer_ZR = alloc->Alloc(buffer_size_ZR); + memset(buffer_ZR, 0, buffer_size_ZR); + + pre_packed_recurrent_ZR_.buffer_ = BufferUniquePtr(buffer_ZR, BufferDeleter(alloc)); + pre_packed_recurrent_ZR_.buffer_size_ = buffer_size_ZR; + pre_packed_recurrent_ZR_.shape_ = shape; // original shape, not used in prepacked calculations, but useful for validation + pre_packed_recurrent_ZR_.weights_size_ = ZR_packed_size; + + auto* buffer_H = alloc->Alloc(buffer_size_H); + memset(buffer_H, 0, buffer_size_H); + + pre_packed_recurrent_H_.buffer_ = BufferUniquePtr(buffer_H, BufferDeleter(alloc)); + pre_packed_recurrent_H_.buffer_size_ = buffer_size_H; + pre_packed_recurrent_H_.shape_ = shape; // original shape, not used in prepacked calculations, but useful for validation + pre_packed_recurrent_H_.weights_size_ = H_packed_size; + + const auto hidden_2_step = hidden_size_x_2 * K; + const auto hidden_1_step = hidden_size_ * K; // square + const auto* weights_data = weights.Data(); + MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + weights_data += hidden_2_step; + MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + + if (num_directions == 2) { + weights_data += hidden_1_step; + buffer_ZR = static_cast(buffer_ZR) + ZR_packed_size; + MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + + weights_data += hidden_2_step; + buffer_H = static_cast(buffer_H) + H_packed_size; + MlasGemmPackB(CblasTrans, narrow(hidden_size_), 
narrow(K), weights_data, narrow(K), buffer_H); + } + + return true; +} + +Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) { + is_packed = false; + + const bool share_prepacked_weights = (prepacked_weights != nullptr); + + // only pack float data type + if (tensor.IsDataType()) { + if (input_idx == 1) { + // input weights + is_packed = TryPackInputWeights(tensor, alloc); + if (is_packed && share_prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(pre_packed_input_weights_.buffer_)); + prepacked_weights->buffer_sizes_.push_back(pre_packed_input_weights_.buffer_size_); + } + } else if (input_idx == 2) { + // for two directions we need to split recurrent in two buffers + is_packed = TryPackRecurrentWeights(tensor, alloc); + if (is_packed && share_prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(pre_packed_recurrent_ZR_.buffer_)); + prepacked_weights->buffer_sizes_.push_back(pre_packed_recurrent_ZR_.buffer_size_); + prepacked_weights->buffers_.push_back(std::move(pre_packed_recurrent_H_.buffer_)); + prepacked_weights->buffer_sizes_.push_back(pre_packed_recurrent_H_.buffer_size_); + } + } + } + return Status::OK(); +} + +Status DeepCpuGruOp::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + + if (input_idx == 1) { + pre_packed_input_weights_.buffer_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + } else if (input_idx == 2) { + pre_packed_recurrent_ZR_.buffer_ = std::move(prepacked_buffers[0]); + pre_packed_recurrent_H_.buffer_ = std::move(prepacked_buffers[1]); + used_shared_buffers = true; + } + + return Status::OK(); +} + Status DeepCpuGruOp::Compute(OpKernelContext* context) const { const Tensor& X = *context->Input(0); // inputs. 
[seq_length, batch_size, input_size] @@ -270,9 +432,9 @@ template Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { concurrency::ThreadPool* thread_pool = context.GetOperatorThreadPool(); - const Tensor& X = *context.Input(0); // inputs. [seq_length, batch_size, input_size] - const Tensor& W = *context.Input(1); // weights. [num_directions, 3*hidden_size, input_size] - const Tensor& R = *context.Input(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size] + const Tensor& X = *context.Input(0); // inputs. [seq_length, batch_size, input_size] + const Tensor* W = (pre_packed_input_weights_.buffer_) ? nullptr : context.Input(1); // weights. [num_directions, 3*hidden_size, input_size] + const Tensor* R = (pre_packed_recurrent_ZR_.buffer_) ? nullptr : context.Input(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size] // optional const auto* B = context.Input(3); // bias. [num_directions, 6*hidden_size] @@ -285,7 +447,13 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { int batch_size = narrow(X_shape[1]); int input_size = narrow(X_shape[2]); - auto status = ValidateCommonRnnInputs(X, W.Shape(), R.Shape(), B, 3, sequence_lens, initial_h, num_directions_, hidden_size_); +//#ifdef _DEBUG +// std::cout << "GRU: seq_len: " << seq_length << " batch_size: " << batch_size << " input_size: " << input_size << std::endl; +//#endif + + const auto& W_shape = (W != nullptr) ? W->Shape() : pre_packed_input_weights_.shape_; + const auto& R_shape = (R != nullptr) ? 
R->Shape() : pre_packed_recurrent_ZR_.shape_; // original shape saved + auto status = ValidateCommonRnnInputs(X, W_shape, R_shape, B, 3, sequence_lens, initial_h, num_directions_, hidden_size_); ORT_RETURN_IF_ERROR(status); // GRU outputs are optional but must be in the same order @@ -308,17 +476,32 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { AllocatorPtr alloc; status = context.GetTempSpaceAllocator(&alloc); ORT_RETURN_IF_ERROR(status); - gsl::span input_weights = W.DataAsSpan(); - gsl::span recurrent_weights = R.DataAsSpan(); + const auto* input_weights = (W != nullptr) ? W->Data() : nullptr; + const auto recurrent_weights = (R != nullptr) ? R->DataAsSpan() : gsl::span(); gsl::span bias = B != nullptr ? B->DataAsSpan() : gsl::span(); // spans for first direction const size_t input_weights_size_per_direction = 3 * hidden_size_ * input_size; - const size_t recurrent_weights_size_per_direction = 3 * hidden_size_ * hidden_size_; + const size_t recurrent_weights_size_per_direction_ZR = 2 * hidden_size_ * hidden_size_; + const size_t recurrent_weights_size_per_direction_H = hidden_size_ * hidden_size_; + const size_t recurrent_weights_size_per_direction = recurrent_weights_size_per_direction_ZR + recurrent_weights_size_per_direction_H; const size_t bias_size_per_direction = 6 * hidden_size_; - gsl::span input_weights_1 = input_weights.subspan(0, input_weights_size_per_direction); - gsl::span recurrent_weights_1 = recurrent_weights.subspan(0, recurrent_weights_size_per_direction); + GemmWeights input_weights_1(0, input_weights, input_weights_size_per_direction, pre_packed_input_weights_); + + GemmWeights recurrent_weights_ZR_1; + GemmWeights recurrent_weights_H_1; + if (R != nullptr) { + auto recurrent_ZR_span = recurrent_weights.subspan(0, recurrent_weights_size_per_direction_ZR); + auto recurrent_H_span = recurrent_weights.subspan(recurrent_weights_size_per_direction_ZR, recurrent_weights_size_per_direction_H); + 
recurrent_weights_ZR_1.Init(0, recurrent_ZR_span.data(), recurrent_ZR_span.size(), pre_packed_recurrent_ZR_, nullptr); + recurrent_weights_H_1.Init(0, recurrent_H_span.data(), recurrent_H_span.size(), pre_packed_recurrent_H_, nullptr); + } else { + // The data ptr and the size are taken from pre-packed buffer + recurrent_weights_ZR_1.Init(0, nullptr, 0, pre_packed_recurrent_ZR_, nullptr); + recurrent_weights_H_1.Init(0, nullptr, 0, pre_packed_recurrent_H_, nullptr); + } + gsl::span bias_1 = bias.empty() ? bias : bias.subspan(0, bias_size_per_direction); gsl::span input = X.DataAsSpan(); @@ -352,11 +535,23 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { gsl::span hidden_output_1 = hidden_output.subspan(0, hidden_output_size_per_direction); if (direction_ == Direction::kBidirectional) { - // spans for second direction - gsl::span input_weights_2 = input_weights.subspan(input_weights_size_per_direction, - input_weights_size_per_direction); - gsl::span recurrent_weights_2 = recurrent_weights.subspan(recurrent_weights_size_per_direction, - recurrent_weights_size_per_direction); + GemmWeights input_weights_2(1, input_weights, input_weights_size_per_direction, pre_packed_input_weights_); + + GemmWeights recurrent_weights_ZR_2; + GemmWeights recurrent_weights_H_2; + if (R != nullptr) { + auto recurrent_ZR_span = recurrent_weights.subspan(recurrent_weights_size_per_direction, recurrent_weights_size_per_direction_ZR); + auto recurrent_H_span = recurrent_weights.subspan(recurrent_weights_size_per_direction + recurrent_weights_size_per_direction_ZR, + recurrent_weights_size_per_direction_H); + // Indices are zero since the span already provides the correct view even though we are taking the second direction weights + recurrent_weights_ZR_2.Init(0, recurrent_ZR_span.data(), recurrent_ZR_span.size(), pre_packed_recurrent_ZR_, nullptr); + recurrent_weights_H_2.Init(0, recurrent_H_span.data(), recurrent_H_span.size(), pre_packed_recurrent_H_, nullptr); + } 
else { + // The data ptr and the size are taken from pre-packed buffer + recurrent_weights_ZR_2.Init(1, nullptr, 0, pre_packed_recurrent_ZR_, nullptr); + recurrent_weights_H_2.Init(1, nullptr, 0, pre_packed_recurrent_H_, nullptr); + } + gsl::span bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction); gsl::span initial_hidden_2 = initial_hidden.empty() @@ -375,7 +570,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], clip_, thread_pool); - fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, + fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_ZR_1, recurrent_weights_H_1, output_1, hidden_output_1); detail::UniDirectionalGru bw(alloc, seq_length, batch_size, input_size, hidden_size_, @@ -383,7 +578,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[2], activation_funcs_.Entries()[3], clip_, thread_pool); - bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, + bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_ZR_2, recurrent_weights_H_2, output_2, hidden_output_2); } else { detail::UniDirectionalGru gru_p(alloc, seq_length, batch_size, input_size, hidden_size_, @@ -391,7 +586,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const { activation_funcs_.Entries()[0], activation_funcs_.Entries()[1], clip_, thread_pool); - gru_p.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, + gru_p.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_ZR_1, recurrent_weights_H_1, output_1, hidden_output_1); } @@ -415,12 +610,12 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, const int hidden_size, const bool linear_before_reset, Direction direction, - 
const gsl::span& bias, - const gsl::span& initial_hidden_state, + gsl::span bias, + gsl::span initial_hidden_state, const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g, const float clip, onnxruntime::concurrency::ThreadPool* ttp) - : allocator_(allocator), + : allocator_(std::move(allocator)), seq_length_(seq_length), batch_size_(batch_size), input_size_(input_size), @@ -488,11 +683,12 @@ UniDirectionalGru::UniDirectionalGru(AllocatorPtr allocator, } template -void UniDirectionalGru::Compute(const gsl::span& inputs_arg, - const gsl::span& sequence_lengths_arg, +void UniDirectionalGru::Compute(gsl::span inputs_arg, + gsl::span sequence_lengths_arg, const int num_directions, - const gsl::span& input_weights, - const gsl::span& recurrent_weights, + const GemmWeights& input_weights_s, + const GemmWeights& recurrent_weightsZR_s, + const GemmWeights& recurrent_weightsH_s, gsl::span& outputs, gsl::span& final_hidden_state) { using span_T_const_iter = typename gsl::span::iterator; @@ -508,12 +704,21 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, sequence_lengths = sequence_lengths_; } - DumpMatrix("Inputs", inputs.data(), seq_length_ * batch_size_, input_size_); - DumpMatrix("input_weights", input_weights.data(), 3 * hidden_size_, input_size_); - DumpMatrix("recurrent_weights", recurrent_weights.data(), 3 * hidden_size_, hidden_size_); + gsl::span input_weights; + if (!input_weights_s.is_prepacked_) { + input_weights = input_weights_s.GetUnpackedSpan(); + DumpMatrix("Inputs", inputs.data(), seq_length_ * batch_size_, input_size_); + DumpMatrix("input_weights", input_weights.data(), 3 * hidden_size_, input_size_); + DumpMatrix("recurrent_weights", recurrent_weights.data(), 3 * hidden_size_, hidden_size_); + } - gsl::span recurrent_weightsZR = recurrent_weights.subspan(0, 2 * hidden_size_ * hidden_size_); - gsl::span recurrent_weightsH = recurrent_weights.subspan(2 * hidden_size_ * hidden_size_, hidden_size_ * 
hidden_size_); + gsl::span recurrent_weightsZR; + if (!recurrent_weightsZR_s.is_prepacked_) + recurrent_weightsZR = recurrent_weightsZR_s.GetUnpackedSpan(); + + gsl::span recurrent_weightsH; + if (!recurrent_weightsH_s.is_prepacked_) + recurrent_weightsH = recurrent_weightsH_s.GetUnpackedSpan(); gsl::span original_outputs = outputs; const bool output_sequence = !outputs.empty(); @@ -541,13 +746,28 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, float alpha = 1.0f; // apply weights to all the inputs - ComputeGemm(total_rows, hidden_size_x3, input_size_, alpha, - inputs.begin(), inputs.end(), - input_size_, - input_weights.begin(), input_weights.end(), - input_size_, 0.f, - outputZRH_.begin(), outputZRH_.end(), - hidden_size_x3, ttp_); + if (!input_weights_s.is_prepacked_) { + ComputeGemm(total_rows, hidden_size_x3, input_size_, alpha, + inputs.begin(), inputs.end(), + input_size_, + input_weights.begin(), input_weights.end(), + input_size_, 0.f, + outputZRH_.begin(), outputZRH_.end(), + hidden_size_x3, ttp_); + } else { + MlasGemm( + CblasNoTrans, + static_cast(total_rows), + static_cast(hidden_size_x3), + static_cast(input_size_), + alpha, + inputs.data(), + static_cast(input_size_), + input_weights_s.buffer_, + 0.0f, + &*outputZRH_.begin(), + static_cast(hidden_size_x3), ttp_); + } DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_); @@ -611,13 +831,25 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, // calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_ // Ht-1 * R[zr] + Xt*(W[zr]^T) - ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha, - prev_Ht, prev_Ht_end, - hidden_size_, - recurrent_weightsZR.begin(), recurrent_weightsZR.end(), - hidden_size_, 1.f, // beta == 1 so we add existing values in outputZRH_ - outputZRH_.begin() + out_added_offset, outputZRH_.end(), - hidden_size_x3, ttp_); + if (!recurrent_weightsZR_s.is_prepacked_) { + 
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha, + prev_Ht, prev_Ht_end, + hidden_size_, + recurrent_weightsZR.begin(), recurrent_weightsZR.end(), + hidden_size_, 1.f, // beta == 1 so we add existing values in outputZRH_ + outputZRH_.begin() + out_added_offset, outputZRH_.end(), + hidden_size_x3, ttp_); + } else { + MlasGemm( + CblasNoTrans, + static_cast(batch_size_), static_cast(hidden_size_x2), static_cast(hidden_size_), alpha, + &*prev_Ht, + static_cast(hidden_size_), + recurrent_weightsZR_s.buffer_, + 1.f, + &*(outputZRH_.begin() + out_added_offset), + static_cast(hidden_size_x3), ttp_); + } DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str, outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3); @@ -631,15 +863,27 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, } // compute Ht-1 * (Rh^T) + Rbh - ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, - prev_Ht, prev_Ht_end, // Ht-1 - hidden_size_, - recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T - hidden_size_, - use_bias_ ? 1.f : 0.f, // don't add values in linear_output_ if no bias input - linear_output_.begin(), - linear_output_.end(), // pre: Rbh if use_bias_, post:output - hidden_size_, ttp_); + if (!recurrent_weightsH_s.is_prepacked_) { + ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, + prev_Ht, prev_Ht_end, // Ht-1 + hidden_size_, + recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T + hidden_size_, + use_bias_ ? 1.f : 0.f, // don't add values in linear_output_ if no bias input + linear_output_.begin(), + linear_output_.end(), // pre: Rbh if use_bias_, post:output + hidden_size_, ttp_); + } else { + MlasGemm( + CblasNoTrans, + static_cast(batch_size_), static_cast(hidden_size_), static_cast(hidden_size_), alpha, + &*prev_Ht, + static_cast(hidden_size_), + recurrent_weightsH_s.buffer_, + use_bias_ ? 
1.f : 0.f, // don't add values in linear_output_ if no bias input + &*linear_output_.begin(), + static_cast(hidden_size_), ttp_); + } DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_); } @@ -704,19 +948,31 @@ void UniDirectionalGru::Compute(const gsl::span& inputs_arg, auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2; // Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh - ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, - cur_h_local, cur_h_local_end, // rt (.) Ht-1 - hidden_size_, - recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T - hidden_size_, 1.f, // beta == 1 to add Xt*(Wh^T) from out_H - out_H, outputZRH_.end(), - hidden_size_x3, ttp_); + if (!recurrent_weightsH_s.is_prepacked_) { + ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha, + cur_h_local, cur_h_local_end, // rt (.) Ht-1 + hidden_size_, + recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T + hidden_size_, 1.f, // beta == 1 to add Xt*(Wh^T) from out_H + out_H, outputZRH_.end(), + hidden_size_x3, ttp_); + } else { + MlasGemm( + CblasNoTrans, + static_cast(batch_size_), static_cast(hidden_size_), static_cast(hidden_size_), alpha, + &*cur_h_local, + static_cast(hidden_size_), + recurrent_weightsH_s.buffer_, + 1.f, // beta == 1 to add Xt*(Wh^T) from out_H + &*out_H, + static_cast(hidden_size_x3), ttp_); + } } DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset, batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3); - //2nd Set of Activations + // 2nd Set of Activations span_T_iter output; span_T_iter output_end; if (output_sequence) { diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 8d805f2a17..c893c40468 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -52,8 +52,8 @@ class DeepCpuGruOp final : public OpKernel { 
activation_func_betas); layout_ = info.GetAttrOrDefault("layout", static_cast(0)); - ORT_ENFORCE(layout_ == 0, - "Batchwise recurrent operations (layout == 1) are not supported. If you need support create a github issue with justification."); + ORT_ENFORCE(layout_ == 0, + "Batchwise recurrent operations (layout == 1) are not supported. If you need support create a github issue with justification."); } Status Compute(OpKernelContext* context) const override; @@ -61,16 +61,36 @@ class DeepCpuGruOp final : public OpKernel { ~DeepCpuGruOp() override = default; private: + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + bool TryPackInputWeights(const Tensor& weight, AllocatorPtr& alloc); + + bool TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& alloc); + rnn::detail::Direction direction_; int num_directions_; - int hidden_size_ {}; + int hidden_size_{}; float clip_; - int linear_before_reset_ {}; + int linear_before_reset_{}; int64_t layout_; rnn::detail::ActivationFuncs activation_funcs_; + // This kernel supports either forward or bidirectional + // This is split in half for bidirectional, but we prepack it in the same buffer + rnn::detail::PackedWeights pre_packed_input_weights_; + // recurrent_weights_ZR_ fwd, followed by bwd + rnn::detail::PackedWeights pre_packed_recurrent_ZR_; + // recurrent_weights_H_ fwd, followed by bwd + rnn::detail::PackedWeights pre_packed_recurrent_H_; + template Status ComputeImpl(OpKernelContext& context) const; }; diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index 076e20430e..90d310ef34 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -209,11 
+209,22 @@ struct GemmWeights { } else { is_prepacked_ = false; buffer_ = weights_data + weights_size * idx; + weights_size_ = weights_size; } } + /// + /// Get span + /// + /// + gsl::span GetUnpackedSpan() const { + ORT_ENFORCE(!is_prepacked_, "Can not get unpacked span from prepacked weights"); + return gsl::span(reinterpret_cast(buffer_), weights_size_); + } + bool is_prepacked_{false}; const void* buffer_{nullptr}; + size_t weights_size_{0}; QuantizationParameter* quant_para_{nullptr}; };