mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
optimize CPU implementation of EmbedLayerNorm (#2491)
* optimize CPU implementation of EmbedLayerNorm * use atomic in parallelization
This commit is contained in:
parent
e57b735bb9
commit
ccbd778d0d
1 changed files with 55 additions and 29 deletions
|
|
@ -4,6 +4,9 @@
|
|||
#include "embed_layer_norm.h"
|
||||
#include "embed_layer_norm_helper.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -60,40 +63,63 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
|
|||
int position_embedding_length = static_cast<int>(position_embedding->Shape()[0]);
|
||||
int segment_embedding_length = static_cast<int>(segment_embedding->Shape()[0]);
|
||||
|
||||
ConstEigenArrayMap<T> word_embedding_arr(word_embedding->template Data<T>(), hidden_size, word_embedding_length);
|
||||
ConstEigenArrayMap<T> position_embedding_arr(position_embedding->template Data<T>(), hidden_size, position_embedding_length);
|
||||
ConstEigenArrayMap<T> segment_embedding_arr(segment_embedding->template Data<T>(), hidden_size, segment_embedding_length);
|
||||
ConstEigenVectorMap<T> gamma_vector(gamma->template Data<T>(), hidden_size);
|
||||
ConstEigenVectorMap<T> beta_vector(beta->template Data<T>(), hidden_size);
|
||||
EigenArrayMap<T> output_arr(output->template MutableData<T>(), hidden_size, batch_size * sequence_length);
|
||||
auto input_ids_data = input_ids->template Data<int>();
|
||||
auto segment_ids_data = segment_ids->template Data<int>();
|
||||
auto word_embedding_data = word_embedding->template Data<T>();
|
||||
auto position_embedding_data = position_embedding->template Data<T>();
|
||||
auto segment_embedding_data = segment_embedding->template Data<T>();
|
||||
auto gamma_data = gamma->template Data<T>();
|
||||
auto beta_data = beta->template Data<T>();
|
||||
auto output_data = output->template MutableData<T>();
|
||||
|
||||
// Calculate output
|
||||
{
|
||||
size_t index = 0;
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
for (int s = 0; s < sequence_length; s++) {
|
||||
int word_col_index = input_ids->template Data<int>()[index];
|
||||
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "word_col_index out of range");
|
||||
}
|
||||
int position_col_index = s;
|
||||
if (position_col_index >= position_embedding_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "position_col_index out of range");
|
||||
}
|
||||
int segment_col_index = segment_ids->template Data<int>()[index];
|
||||
if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "segment_col_index out of range");
|
||||
}
|
||||
std::atomic_bool failed{false};
|
||||
|
||||
output_arr.col(index) = word_embedding_arr.col(word_col_index) +
|
||||
position_embedding_arr.col(position_col_index) +
|
||||
segment_embedding_arr.col(segment_col_index);
|
||||
output_arr.col(index) -= output_arr.col(index).mean();
|
||||
output_arr.col(index) /= static_cast<T>(sqrt(output_arr.col(index).pow(2).mean() + 1.0e-13));
|
||||
output_arr.col(index) *= gamma_vector.array();
|
||||
output_arr.col(index) += beta_vector.array();
|
||||
index++;
|
||||
int n = batch_size * sequence_length;
|
||||
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), n, [=, &failed](int index) {
|
||||
int word_col_index = input_ids_data[index];
|
||||
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
|
||||
failed.store(true, std::memory_order_release);
|
||||
return;
|
||||
}
|
||||
int position_col_index = index % sequence_length;
|
||||
if (position_col_index >= position_embedding_length) {
|
||||
failed.store(true, std::memory_order_release);
|
||||
return;
|
||||
}
|
||||
int segment_col_index = segment_ids_data[index];
|
||||
if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) {
|
||||
failed.store(true, std::memory_order_release);
|
||||
return;
|
||||
}
|
||||
|
||||
T* y = output_data + index * hidden_size;
|
||||
const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size;
|
||||
const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size;
|
||||
const T* input_segment_embedding = segment_embedding_data + segment_col_index * hidden_size;
|
||||
|
||||
T sum = static_cast<T>(0);
|
||||
for (int i = 0; i < hidden_size; i++) {
|
||||
T subtotal = input_word_embedding[i] + input_position_embedding[i] + input_segment_embedding[i];
|
||||
y[i] = subtotal;
|
||||
sum += subtotal;
|
||||
}
|
||||
T mean = sum / hidden_size;
|
||||
sum = 0;
|
||||
for (int i = 0; i < hidden_size; i++) {
|
||||
T a = y[i] - mean;
|
||||
y[i] = a;
|
||||
sum += a * a;
|
||||
}
|
||||
T e = sqrt(sum / hidden_size + static_cast<T>(1.0e-13));
|
||||
for (int i = 0; i < hidden_size; i++) {
|
||||
y[i] = y[i] / e * gamma_data[i] + beta_data[i];
|
||||
}
|
||||
});
|
||||
|
||||
if (failed.load(std::memory_order_acquire)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input index out of range");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue