optimize CPU implementation of EmbedLayerNorm (#2491)

* optimize CPU implementation of EmbedLayerNorm
* use atomic in parallelization
This commit is contained in:
Yulong Wang 2019-11-27 12:34:57 -08:00 committed by Tianlei Wu
parent e57b735bb9
commit ccbd778d0d

View file

@ -4,6 +4,9 @@
#include "embed_layer_norm.h"
#include "embed_layer_norm_helper.h"
#include "core/util/math_cpuonly.h"
#include "core/platform/threadpool.h"
#include <atomic>
namespace onnxruntime {
namespace contrib {
@ -60,40 +63,63 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
int position_embedding_length = static_cast<int>(position_embedding->Shape()[0]);
int segment_embedding_length = static_cast<int>(segment_embedding->Shape()[0]);
ConstEigenArrayMap<T> word_embedding_arr(word_embedding->template Data<T>(), hidden_size, word_embedding_length);
ConstEigenArrayMap<T> position_embedding_arr(position_embedding->template Data<T>(), hidden_size, position_embedding_length);
ConstEigenArrayMap<T> segment_embedding_arr(segment_embedding->template Data<T>(), hidden_size, segment_embedding_length);
ConstEigenVectorMap<T> gamma_vector(gamma->template Data<T>(), hidden_size);
ConstEigenVectorMap<T> beta_vector(beta->template Data<T>(), hidden_size);
EigenArrayMap<T> output_arr(output->template MutableData<T>(), hidden_size, batch_size * sequence_length);
auto input_ids_data = input_ids->template Data<int>();
auto segment_ids_data = segment_ids->template Data<int>();
auto word_embedding_data = word_embedding->template Data<T>();
auto position_embedding_data = position_embedding->template Data<T>();
auto segment_embedding_data = segment_embedding->template Data<T>();
auto gamma_data = gamma->template Data<T>();
auto beta_data = beta->template Data<T>();
auto output_data = output->template MutableData<T>();
// Calculate output
{
size_t index = 0;
for (int b = 0; b < batch_size; b++) {
for (int s = 0; s < sequence_length; s++) {
int word_col_index = input_ids->template Data<int>()[index];
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "word_col_index out of range");
}
int position_col_index = s;
if (position_col_index >= position_embedding_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "position_col_index out of range");
}
int segment_col_index = segment_ids->template Data<int>()[index];
if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "segment_col_index out of range");
}
std::atomic_bool failed{false};
output_arr.col(index) = word_embedding_arr.col(word_col_index) +
position_embedding_arr.col(position_col_index) +
segment_embedding_arr.col(segment_col_index);
output_arr.col(index) -= output_arr.col(index).mean();
output_arr.col(index) /= static_cast<T>(sqrt(output_arr.col(index).pow(2).mean() + 1.0e-13));
output_arr.col(index) *= gamma_vector.array();
output_arr.col(index) += beta_vector.array();
index++;
int n = batch_size * sequence_length;
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), n, [=, &failed](int index) {
int word_col_index = input_ids_data[index];
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
int position_col_index = index % sequence_length;
if (position_col_index >= position_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
int segment_col_index = segment_ids_data[index];
if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
T* y = output_data + index * hidden_size;
const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size;
const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size;
const T* input_segment_embedding = segment_embedding_data + segment_col_index * hidden_size;
T sum = static_cast<T>(0);
for (int i = 0; i < hidden_size; i++) {
T subtotal = input_word_embedding[i] + input_position_embedding[i] + input_segment_embedding[i];
y[i] = subtotal;
sum += subtotal;
}
T mean = sum / hidden_size;
sum = 0;
for (int i = 0; i < hidden_size; i++) {
T a = y[i] - mean;
y[i] = a;
sum += a * a;
}
T e = sqrt(sum / hidden_size + static_cast<T>(1.0e-13));
for (int i = 0; i < hidden_size; i++) {
y[i] = y[i] / e * gamma_data[i] + beta_data[i];
}
});
if (failed.load(std::memory_order_acquire)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input index out of range");
}
}