Improve cache locality and perf of DeepGru on CPU (#13582)

### Description
<!-- Describe your changes. -->
Introduce Gemm weights pre-pack.

### Motivation and Context
A first-party (1P) customer requested a performance improvement for DeepGru,
which consumes the bulk of the CPU time in their model. This change provides
measurable performance improvements.

Customer model numbers.

gru: mean = 356 us; 1 ms = 99.8th percentile; 99th percentile = 665 us
(yuslepukhin/deep_gru_opt)
main: mean = 375 us; 1 ms = 99.8th percentile; 99th percentile = 695 us (where
yuslepukhin/deep_gru_opt branched off main)
1.13.1: mean = 391 us; 1 ms = 99.6th percentile; 99th percentile = 744 us
This commit is contained in:
Dmitri Smirnov 2022-11-09 09:59:38 -08:00 committed by GitHub
parent e0361e6256
commit bbedf2c4c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 356 additions and 69 deletions

View file

@ -10,11 +10,12 @@
#endif
#include "core/providers/cpu/rnn/deep_cpu_gru.h"
#include "core/common/narrow.h"
#ifdef _MSC_VER
#pragma warning(pop)
#endif
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)
@ -168,13 +169,15 @@ template <typename T>
class UniDirectionalGru {
public:
UniDirectionalGru(AllocatorPtr allocator, int seq_length, int batch_size, int input_size, int hidden_size,
bool linear_before_reset, Direction direction, const gsl::span<const T>& bias,
const gsl::span<const T>& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f,
bool linear_before_reset, Direction direction, gsl::span<const T> bias,
gsl::span<const T> initial_hidden_state, const ActivationFuncs::Entry& activation_func_f,
const ActivationFuncs::Entry& activation_func_g, float clip,
onnxruntime::concurrency::ThreadPool* ttp);
void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
void Compute(gsl::span<const T> inputs, gsl::span<const int> sequence_lengths, int num_directions,
const GemmWeights<T>& input_weights,
const GemmWeights<T>& recurrent_weights_ZR,
const GemmWeights<T>& recurrent_weights_H,
gsl::span<T>& outputs, gsl::span<T>& final_hidden_state);
~UniDirectionalGru() = default;
@ -249,6 +252,165 @@ class UniDirectionalGru {
#define DumpMatrix(...) ((void)0)
#endif
// Packs the input weights W with MlasGemmPackB, one packed slab per direction,
// laid out back to back in a single zero-filled buffer owned by
// pre_packed_input_weights_.
//
// weights: [num_directions, 3*hidden_size, input_size]
//
// Returns false (leaving pre_packed_input_weights_ untouched) when the tensor
// is not 3-D, when its leading dimension differs from num_directions_, or when
// MlasGemmPackBSize reports a packed size of zero. Returns true once every
// direction has been packed.
bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& alloc) {
  const auto& dims = weights.Shape();
  // Short-circuit evaluation keeps dims[0] from being read on a wrong-rank tensor.
  if (dims.NumDimensions() != 3 || dims[0] != num_directions_) {
    return false;
  }

  const auto direction_count = dims[0];
  const size_t rows = static_cast<size_t>(dims[1]);
  const size_t cols = static_cast<size_t>(dims[2]);

  const size_t bytes_per_direction = MlasGemmPackBSize(rows, cols);
  if (bytes_per_direction == 0) {
    return false;
  }

  const size_t total_bytes = SafeInt<size_t>(bytes_per_direction) * direction_count;
  auto* packed_base = alloc->Alloc(total_bytes);

  // The packed layout may contain padding. Zero the whole buffer so that the
  // padding is never left uninitialized and the buffer hashes identically if
  // it is later cached for sharing between sessions.
  memset(packed_base, 0, total_bytes);

  pre_packed_input_weights_.buffer_ = BufferUniquePtr(packed_base, BufferDeleter(alloc));
  pre_packed_input_weights_.buffer_size_ = total_bytes;
  pre_packed_input_weights_.shape_ = dims;
  pre_packed_input_weights_.weights_size_ = bytes_per_direction;

  // Walk source (elements) and destination (bytes) cursors in lock step,
  // advancing by one direction's worth of data after each pack.
  const size_t elements_per_direction = rows * cols;
  const auto* src = weights.Data<float>();
  auto* dst = static_cast<uint8_t*>(packed_base);
  for (int64_t dir = 0; dir < direction_count; ++dir) {
    MlasGemmPackB(CblasTrans, rows, cols, src, cols, dst);
    src += elements_per_direction;
    dst += bytes_per_direction;
  }

  return true;
}
// Packs the recurrent weights R with MlasGemmPackB into two separate buffers:
// one holding the first 2*hidden_size rows of each direction (ZR) and one
// holding the remaining hidden_size rows (H). Within each buffer the first
// direction's slab is followed by the second direction's slab, if any.
//
// recurrence weights: [num_directions, 3*hidden_size, hidden_size]
//
// Returns false when the tensor is not 3-D, when its shape does not match
// num_directions_ / hidden_size_, or when MlasGemmPackBSize reports a packed
// size of zero for either part. Returns true once packing is complete.
bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& alloc) {
  const auto& dims = weights.Shape();
  // Short-circuit evaluation keeps dims[0] from being read on a wrong-rank tensor.
  if (dims.NumDimensions() != 3 || dims[0] != num_directions_) {
    return false;
  }

  const auto direction_count = dims[0];
  const auto gate_rows = dims[1];
  const auto cols = dims[2];
  if (gate_rows != SafeInt<int64_t>(cols) * 3 || cols != hidden_size_) {
    return false;
  }

  // Rows belonging to the ZR part: everything except the trailing H rows.
  const auto zr_rows = gate_rows - hidden_size_;

  // Two packed buffers are produced, one for ZR weights and one for H weights.
  const size_t zr_n = narrow<size_t>(zr_rows);
  const size_t k = narrow<size_t>(cols);
  const size_t zr_bytes = MlasGemmPackBSize(zr_n, k);
  if (zr_bytes == 0) {
    return false;
  }

  const size_t h_n = narrow<size_t>(hidden_size_);
  const size_t h_bytes = MlasGemmPackBSize(h_n, k);
  if (h_bytes == 0) {
    return false;
  }

  const size_t zr_total = SafeInt<size_t>(zr_bytes) * direction_count;
  const size_t h_total = SafeInt<size_t>(h_bytes) * direction_count;

  auto* zr_base = alloc->Alloc(zr_total);
  memset(zr_base, 0, zr_total);
  pre_packed_recurrent_ZR_.buffer_ = BufferUniquePtr(zr_base, BufferDeleter(alloc));
  pre_packed_recurrent_ZR_.buffer_size_ = zr_total;
  // Original (unsplit) shape: not used in the pre-packed calculations, but useful for validation.
  pre_packed_recurrent_ZR_.shape_ = dims;
  pre_packed_recurrent_ZR_.weights_size_ = zr_bytes;

  auto* h_base = alloc->Alloc(h_total);
  memset(h_base, 0, h_total);
  pre_packed_recurrent_H_.buffer_ = BufferUniquePtr(h_base, BufferDeleter(alloc));
  pre_packed_recurrent_H_.buffer_size_ = h_total;
  // Original (unsplit) shape: not used in the pre-packed calculations, but useful for validation.
  pre_packed_recurrent_H_.shape_ = dims;
  pre_packed_recurrent_H_.weights_size_ = h_bytes;

  const auto zr_elements = zr_rows * cols;
  const auto h_elements = hidden_size_ * cols;  // square
  const auto* src = weights.Data<float>();
  auto* zr_dst = static_cast<uint8_t*>(zr_base);
  auto* h_dst = static_cast<uint8_t*>(h_base);

  // The first direction is always packed; a second pass runs only when there
  // are exactly two directions.
  const int64_t passes = (direction_count == 2) ? 2 : 1;
  for (int64_t dir = 0; dir < passes; ++dir) {
    MlasGemmPackB(CblasTrans, zr_n, k, src, k, zr_dst + dir * zr_bytes);
    MlasGemmPackB(CblasTrans, h_n, k, src + zr_elements, k, h_dst + dir * h_bytes);
    src += zr_elements + h_elements;
  }

  return true;
}
// Attempts to pre-pack a constant weight input.
//
// Only float tensors are considered, and only input 1 (input weights) and
// input 2 (recurrent weights); for anything else is_packed stays false.
// When packing succeeds and prepacked_weights is non-null, ownership of the
// freshly packed buffer(s) is moved into prepacked_weights together with the
// matching buffer sizes: one buffer for input 1, two (ZR then H) for input 2.
//
// Always returns Status::OK(); the outcome is reported through is_packed.
Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                             bool& is_packed, PrePackedWeights* prepacked_weights) {
  is_packed = false;

  // only pack float data type
  if (!tensor.IsDataType<float>()) {
    return Status::OK();
  }

  switch (input_idx) {
    case 1:
      // input weights
      is_packed = TryPackInputWeights(tensor, alloc);
      if (is_packed && prepacked_weights != nullptr) {
        prepacked_weights->buffers_.push_back(std::move(pre_packed_input_weights_.buffer_));
        prepacked_weights->buffer_sizes_.push_back(pre_packed_input_weights_.buffer_size_);
      }
      break;

    case 2:
      // recurrent weights are split into two buffers: ZR and H
      is_packed = TryPackRecurrentWeights(tensor, alloc);
      if (is_packed && prepacked_weights != nullptr) {
        prepacked_weights->buffers_.push_back(std::move(pre_packed_recurrent_ZR_.buffer_));
        prepacked_weights->buffer_sizes_.push_back(pre_packed_recurrent_ZR_.buffer_size_);
        prepacked_weights->buffers_.push_back(std::move(pre_packed_recurrent_H_.buffer_));
        prepacked_weights->buffer_sizes_.push_back(pre_packed_recurrent_H_.buffer_size_);
      }
      break;

    default:
      break;
  }

  return Status::OK();
}
// Adopts pre-packed buffers supplied by the caller in place of the ones this
// kernel packed itself.
//
// input_idx == 1 consumes prepacked_buffers[0] as the packed input weights.
// input_idx == 2 consumes prepacked_buffers[0] and [1] as the packed recurrent
// ZR and H weights respectively (the order PrePack pushes them in).
//
// used_shared_buffers is set to true only when the buffers were actually
// taken. If prepacked_buffers holds fewer entries than the input requires,
// nothing is moved and used_shared_buffers stays false instead of indexing
// past the end of the vector.
// NOTE(review): the caller is presumed to treat used_shared_buffers == false
// as "shared buffers were not consumed" - confirm against the caller.
//
// Always returns Status::OK().
Status DeepCpuGruOp::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
                                               int input_idx,
                                               /*out*/ bool& used_shared_buffers) {
  used_shared_buffers = false;

  if (input_idx == 1) {
    if (!prepacked_buffers.empty()) {
      pre_packed_input_weights_.buffer_ = std::move(prepacked_buffers[0]);
      used_shared_buffers = true;
    }
  } else if (input_idx == 2) {
    // Both the ZR and the H buffer are needed; take neither unless both exist.
    if (prepacked_buffers.size() >= 2) {
      pre_packed_recurrent_ZR_.buffer_ = std::move(prepacked_buffers[0]);
      pre_packed_recurrent_H_.buffer_ = std::move(prepacked_buffers[1]);
      used_shared_buffers = true;
    }
  }

  return Status::OK();
}
Status DeepCpuGruOp::Compute(OpKernelContext* context) const {
const Tensor& X = *context->Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
@ -270,9 +432,9 @@ template <typename T>
Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
concurrency::ThreadPool* thread_pool = context.GetOperatorThreadPool();
const Tensor& X = *context.Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
const Tensor& W = *context.Input<Tensor>(1); // weights. [num_directions, 3*hidden_size, input_size]
const Tensor& R = *context.Input<Tensor>(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size]
const Tensor& X = *context.Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
const Tensor* W = (pre_packed_input_weights_.buffer_) ? nullptr : context.Input<Tensor>(1); // weights. [num_directions, 3*hidden_size, input_size]
const Tensor* R = (pre_packed_recurrent_ZR_.buffer_) ? nullptr : context.Input<Tensor>(2); // recurrence weights. [num_directions, 3*hidden_size, hidden_size]
// optional
const auto* B = context.Input<Tensor>(3); // bias. [num_directions, 6*hidden_size]
@ -285,7 +447,13 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
int batch_size = narrow<int>(X_shape[1]);
int input_size = narrow<int>(X_shape[2]);
auto status = ValidateCommonRnnInputs(X, W.Shape(), R.Shape(), B, 3, sequence_lens, initial_h, num_directions_, hidden_size_);
//#ifdef _DEBUG
// std::cout << "GRU: seq_len: " << seq_length << " batch_size: " << batch_size << " input_size: " << input_size << std::endl;
//#endif
const auto& W_shape = (W != nullptr) ? W->Shape() : pre_packed_input_weights_.shape_;
const auto& R_shape = (R != nullptr) ? R->Shape() : pre_packed_recurrent_ZR_.shape_; // original shape saved
auto status = ValidateCommonRnnInputs(X, W_shape, R_shape, B, 3, sequence_lens, initial_h, num_directions_, hidden_size_);
ORT_RETURN_IF_ERROR(status);
// GRU outputs are optional but must be in the same order
@ -308,17 +476,32 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
AllocatorPtr alloc;
status = context.GetTempSpaceAllocator(&alloc);
ORT_RETURN_IF_ERROR(status);
gsl::span<const T> input_weights = W.DataAsSpan<T>();
gsl::span<const T> recurrent_weights = R.DataAsSpan<T>();
const auto* input_weights = (W != nullptr) ? W->Data<T>() : nullptr;
const auto recurrent_weights = (R != nullptr) ? R->DataAsSpan<T>() : gsl::span<const T>();
gsl::span<const T> bias = B != nullptr ? B->DataAsSpan<T>() : gsl::span<const T>();
// spans for first direction
const size_t input_weights_size_per_direction = 3 * hidden_size_ * input_size;
const size_t recurrent_weights_size_per_direction = 3 * hidden_size_ * hidden_size_;
const size_t recurrent_weights_size_per_direction_ZR = 2 * hidden_size_ * hidden_size_;
const size_t recurrent_weights_size_per_direction_H = hidden_size_ * hidden_size_;
const size_t recurrent_weights_size_per_direction = recurrent_weights_size_per_direction_ZR + recurrent_weights_size_per_direction_H;
const size_t bias_size_per_direction = 6 * hidden_size_;
gsl::span<const T> input_weights_1 = input_weights.subspan(0, input_weights_size_per_direction);
gsl::span<const T> recurrent_weights_1 = recurrent_weights.subspan(0, recurrent_weights_size_per_direction);
GemmWeights<T> input_weights_1(0, input_weights, input_weights_size_per_direction, pre_packed_input_weights_);
GemmWeights<T> recurrent_weights_ZR_1;
GemmWeights<T> recurrent_weights_H_1;
if (R != nullptr) {
auto recurrent_ZR_span = recurrent_weights.subspan(0, recurrent_weights_size_per_direction_ZR);
auto recurrent_H_span = recurrent_weights.subspan(recurrent_weights_size_per_direction_ZR, recurrent_weights_size_per_direction_H);
recurrent_weights_ZR_1.Init(0, recurrent_ZR_span.data(), recurrent_ZR_span.size(), pre_packed_recurrent_ZR_, nullptr);
recurrent_weights_H_1.Init(0, recurrent_H_span.data(), recurrent_H_span.size(), pre_packed_recurrent_H_, nullptr);
} else {
// The data ptr and the size are taken from pre-packed buffer
recurrent_weights_ZR_1.Init(0, nullptr, 0, pre_packed_recurrent_ZR_, nullptr);
recurrent_weights_H_1.Init(0, nullptr, 0, pre_packed_recurrent_H_, nullptr);
}
gsl::span<const T> bias_1 = bias.empty() ? bias : bias.subspan(0, bias_size_per_direction);
gsl::span<const T> input = X.DataAsSpan<T>();
@ -352,11 +535,23 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
gsl::span<T> hidden_output_1 = hidden_output.subspan(0, hidden_output_size_per_direction);
if (direction_ == Direction::kBidirectional) {
// spans for second direction
gsl::span<const T> input_weights_2 = input_weights.subspan(input_weights_size_per_direction,
input_weights_size_per_direction);
gsl::span<const T> recurrent_weights_2 = recurrent_weights.subspan(recurrent_weights_size_per_direction,
recurrent_weights_size_per_direction);
GemmWeights<T> input_weights_2(1, input_weights, input_weights_size_per_direction, pre_packed_input_weights_);
GemmWeights<T> recurrent_weights_ZR_2;
GemmWeights<T> recurrent_weights_H_2;
if (R != nullptr) {
auto recurrent_ZR_span = recurrent_weights.subspan(recurrent_weights_size_per_direction, recurrent_weights_size_per_direction_ZR);
auto recurrent_H_span = recurrent_weights.subspan(recurrent_weights_size_per_direction + recurrent_weights_size_per_direction_ZR,
recurrent_weights_size_per_direction_H);
// Indices are zero since the span already provides the correct view even though we are taking the second direction weights
recurrent_weights_ZR_2.Init(0, recurrent_ZR_span.data(), recurrent_ZR_span.size(), pre_packed_recurrent_ZR_, nullptr);
recurrent_weights_H_2.Init(0, recurrent_H_span.data(), recurrent_H_span.size(), pre_packed_recurrent_H_, nullptr);
} else {
// The data ptr and the size are taken from pre-packed buffer
recurrent_weights_ZR_2.Init(1, nullptr, 0, pre_packed_recurrent_ZR_, nullptr);
recurrent_weights_H_2.Init(1, nullptr, 0, pre_packed_recurrent_H_, nullptr);
}
gsl::span<const T> bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction);
gsl::span<const T> initial_hidden_2 = initial_hidden.empty()
@ -375,7 +570,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
clip_, thread_pool);
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1,
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_ZR_1, recurrent_weights_H_1,
output_1, hidden_output_1);
detail::UniDirectionalGru<T> bw(alloc, seq_length, batch_size, input_size, hidden_size_,
@ -383,7 +578,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[2],
activation_funcs_.Entries()[3],
clip_, thread_pool);
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2,
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_ZR_2, recurrent_weights_H_2,
output_2, hidden_output_2);
} else {
detail::UniDirectionalGru<T> gru_p(alloc, seq_length, batch_size, input_size, hidden_size_,
@ -391,7 +586,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
clip_, thread_pool);
gru_p.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1,
gru_p.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_ZR_1, recurrent_weights_H_1,
output_1, hidden_output_1);
}
@ -415,12 +610,12 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
const int hidden_size,
const bool linear_before_reset,
Direction direction,
const gsl::span<const T>& bias,
const gsl::span<const T>& initial_hidden_state,
gsl::span<const T> bias,
gsl::span<const T> initial_hidden_state,
const ActivationFuncs::Entry& activation_func_f,
const ActivationFuncs::Entry& activation_func_g,
const float clip, onnxruntime::concurrency::ThreadPool* ttp)
: allocator_(allocator),
: allocator_(std::move(allocator)),
seq_length_(seq_length),
batch_size_(batch_size),
input_size_(input_size),
@ -488,11 +683,12 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
}
template <typename T>
void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
const gsl::span<const int>& sequence_lengths_arg,
void UniDirectionalGru<T>::Compute(gsl::span<const T> inputs_arg,
gsl::span<const int> sequence_lengths_arg,
const int num_directions,
const gsl::span<const T>& input_weights,
const gsl::span<const T>& recurrent_weights,
const GemmWeights<T>& input_weights_s,
const GemmWeights<T>& recurrent_weightsZR_s,
const GemmWeights<T>& recurrent_weightsH_s,
gsl::span<T>& outputs,
gsl::span<T>& final_hidden_state) {
using span_T_const_iter = typename gsl::span<const T>::iterator;
@ -508,12 +704,21 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
sequence_lengths = sequence_lengths_;
}
DumpMatrix("Inputs", inputs.data(), seq_length_ * batch_size_, input_size_);
DumpMatrix("input_weights", input_weights.data(), 3 * hidden_size_, input_size_);
DumpMatrix("recurrent_weights", recurrent_weights.data(), 3 * hidden_size_, hidden_size_);
gsl::span<const T> input_weights;
if (!input_weights_s.is_prepacked_) {
input_weights = input_weights_s.GetUnpackedSpan();
DumpMatrix("Inputs", inputs.data(), seq_length_ * batch_size_, input_size_);
DumpMatrix("input_weights", input_weights.data(), 3 * hidden_size_, input_size_);
DumpMatrix("recurrent_weights", recurrent_weights.data(), 3 * hidden_size_, hidden_size_);
}
gsl::span<const T> recurrent_weightsZR = recurrent_weights.subspan(0, 2 * hidden_size_ * hidden_size_);
gsl::span<const T> recurrent_weightsH = recurrent_weights.subspan(2 * hidden_size_ * hidden_size_, hidden_size_ * hidden_size_);
gsl::span<const T> recurrent_weightsZR;
if (!recurrent_weightsZR_s.is_prepacked_)
recurrent_weightsZR = recurrent_weightsZR_s.GetUnpackedSpan();
gsl::span<const T> recurrent_weightsH;
if (!recurrent_weightsH_s.is_prepacked_)
recurrent_weightsH = recurrent_weightsH_s.GetUnpackedSpan();
gsl::span<T> original_outputs = outputs;
const bool output_sequence = !outputs.empty();
@ -541,13 +746,28 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
float alpha = 1.0f;
// apply weights to all the inputs
ComputeGemm(total_rows, hidden_size_x3, input_size_, alpha,
inputs.begin(), inputs.end(),
input_size_,
input_weights.begin(), input_weights.end(),
input_size_, 0.f,
outputZRH_.begin(), outputZRH_.end(),
hidden_size_x3, ttp_);
if (!input_weights_s.is_prepacked_) {
ComputeGemm(total_rows, hidden_size_x3, input_size_, alpha,
inputs.begin(), inputs.end(),
input_size_,
input_weights.begin(), input_weights.end(),
input_size_, 0.f,
outputZRH_.begin(), outputZRH_.end(),
hidden_size_x3, ttp_);
} else {
MlasGemm(
CblasNoTrans,
static_cast<size_t>(total_rows),
static_cast<size_t>(hidden_size_x3),
static_cast<size_t>(input_size_),
alpha,
inputs.data(),
static_cast<size_t>(input_size_),
input_weights_s.buffer_,
0.0f,
&*outputZRH_.begin(),
static_cast<size_t>(hidden_size_x3), ttp_);
}
DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_);
@ -611,13 +831,25 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
// calculate Ht-1*R[zr], and add to the weighted inputs that are in outputZRH_
// Ht-1 * R[zr] + Xt*(W[zr]^T)
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha,
prev_Ht, prev_Ht_end,
hidden_size_,
recurrent_weightsZR.begin(), recurrent_weightsZR.end(),
hidden_size_, 1.f, // beta == 1 so we add existing values in outputZRH_
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3, ttp_);
if (!recurrent_weightsZR_s.is_prepacked_) {
ComputeGemm(batch_size_, hidden_size_x2, hidden_size_, alpha,
prev_Ht, prev_Ht_end,
hidden_size_,
recurrent_weightsZR.begin(), recurrent_weightsZR.end(),
hidden_size_, 1.f, // beta == 1 so we add existing values in outputZRH_
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3, ttp_);
} else {
MlasGemm(
CblasNoTrans,
static_cast<size_t>(batch_size_), static_cast<size_t>(hidden_size_x2), static_cast<size_t>(hidden_size_), alpha,
&*prev_Ht,
static_cast<size_t>(hidden_size_),
recurrent_weightsZR_s.buffer_,
1.f,
&*(outputZRH_.begin() + out_added_offset),
static_cast<size_t>(hidden_size_x3), ttp_);
}
DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
@ -631,15 +863,27 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
}
// compute Ht-1 * (Rh^T) + Rbh
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
prev_Ht, prev_Ht_end, // Ht-1
hidden_size_,
recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T
hidden_size_,
use_bias_ ? 1.f : 0.f, // don't add values in linear_output_ if no bias input
linear_output_.begin(),
linear_output_.end(), // pre: Rbh if use_bias_, post:output
hidden_size_, ttp_);
if (!recurrent_weightsH_s.is_prepacked_) {
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
prev_Ht, prev_Ht_end, // Ht-1
hidden_size_,
recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T
hidden_size_,
use_bias_ ? 1.f : 0.f, // don't add values in linear_output_ if no bias input
linear_output_.begin(),
linear_output_.end(), // pre: Rbh if use_bias_, post:output
hidden_size_, ttp_);
} else {
MlasGemm(
CblasNoTrans,
static_cast<size_t>(batch_size_), static_cast<size_t>(hidden_size_), static_cast<size_t>(hidden_size_), alpha,
&*prev_Ht,
static_cast<size_t>(hidden_size_),
recurrent_weightsH_s.buffer_,
use_bias_ ? 1.f : 0.f, // don't add values in linear_output_ if no bias input
&*linear_output_.begin(),
static_cast<size_t>(hidden_size_), ttp_);
}
DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
}
@ -704,19 +948,31 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
auto out_H = outputZRH_.begin() + out_added_offset + hidden_size_x2;
// Calculate Xt*(Wh^T) + rt (.) Ht-1 * Rh
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
cur_h_local, cur_h_local_end, // rt (.) Ht-1
hidden_size_,
recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T
hidden_size_, 1.f, // beta == 1 to add Xt*(Wh^T) from out_H
out_H, outputZRH_.end(),
hidden_size_x3, ttp_);
if (!recurrent_weightsH_s.is_prepacked_) {
ComputeGemm(batch_size_, hidden_size_, hidden_size_, alpha,
cur_h_local, cur_h_local_end, // rt (.) Ht-1
hidden_size_,
recurrent_weightsH.begin(), recurrent_weightsH.end(), // Rh^T
hidden_size_, 1.f, // beta == 1 to add Xt*(Wh^T) from out_H
out_H, outputZRH_.end(),
hidden_size_x3, ttp_);
} else {
MlasGemm(
CblasNoTrans,
static_cast<size_t>(batch_size_), static_cast<size_t>(hidden_size_), static_cast<size_t>(hidden_size_), alpha,
&*cur_h_local,
static_cast<size_t>(hidden_size_),
recurrent_weightsH_s.buffer_,
1.f, // beta == 1 to add Xt*(Wh^T) from out_H
&*out_H,
static_cast<size_t>(hidden_size_x3), ttp_);
}
}
DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
batch_size_, hidden_size_, hidden_size_x2, hidden_size_x3);
//2nd Set of Activations
// 2nd Set of Activations
span_T_iter output;
span_T_iter output_end;
if (output_sequence) {

View file

@ -52,8 +52,8 @@ class DeepCpuGruOp final : public OpKernel {
activation_func_betas);
layout_ = info.GetAttrOrDefault("layout", static_cast<int64_t>(0));
ORT_ENFORCE(layout_ == 0,
"Batchwise recurrent operations (layout == 1) are not supported. If you need support create a github issue with justification.");
ORT_ENFORCE(layout_ == 0,
"Batchwise recurrent operations (layout == 1) are not supported. If you need support create a github issue with justification.");
}
Status Compute(OpKernelContext* context) const override;
@ -61,16 +61,36 @@ class DeepCpuGruOp final : public OpKernel {
~DeepCpuGruOp() override = default;
private:
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
int input_idx,
/*out*/ bool& used_shared_buffers) override;
bool TryPackInputWeights(const Tensor& weight, AllocatorPtr& alloc);
bool TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& alloc);
rnn::detail::Direction direction_;
int num_directions_;
int hidden_size_ {};
int hidden_size_{};
float clip_;
int linear_before_reset_ {};
int linear_before_reset_{};
int64_t layout_;
rnn::detail::ActivationFuncs activation_funcs_;
// This kernel supports either forward or bidirectional
// This is split in half for bidirectional, but we prepack it in the same buffer
rnn::detail::PackedWeights pre_packed_input_weights_;
// recurrent_weights_ZR_ fwd, followed by bwd
rnn::detail::PackedWeights pre_packed_recurrent_ZR_;
// recurrent_weights_H_ fwd, followed by bwd
rnn::detail::PackedWeights pre_packed_recurrent_H_;
template <typename T>
Status ComputeImpl(OpKernelContext& context) const;
};

View file

@ -209,11 +209,22 @@ struct GemmWeights {
} else {
is_prepacked_ = false;
buffer_ = weights_data + weights_size * idx;
weights_size_ = weights_size;
}
}
/// <summary>
/// Returns a typed, read-only view over the weights when they are not pre-packed.
/// ORT_ENFORCE fails if this instance refers to pre-packed weights.
/// </summary>
/// <returns>Span of weights_size_ elements of T starting at buffer_.</returns>
gsl::span<const T> GetUnpackedSpan() const {
  ORT_ENFORCE(!is_prepacked_, "Can not get unpacked span from prepacked weights");
  // buffer_ is stored type-erased; recover the element type for the span.
  const auto* typed_begin = static_cast<const T*>(buffer_);
  return gsl::span<const T>(typed_begin, weights_size_);
}
bool is_prepacked_{false};
const void* buffer_{nullptr};
size_t weights_size_{0};
QuantizationParameter* quant_para_{nullptr};
};