diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index b3b77aa0e1..21a24e7e4c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -4,6 +4,9 @@ #include "embed_layer_norm.h" #include "embed_layer_norm_helper.h" #include "core/util/math_cpuonly.h" +#include "core/platform/threadpool.h" + +#include namespace onnxruntime { namespace contrib { @@ -60,40 +63,63 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { int position_embedding_length = static_cast(position_embedding->Shape()[0]); int segment_embedding_length = static_cast(segment_embedding->Shape()[0]); - ConstEigenArrayMap word_embedding_arr(word_embedding->template Data(), hidden_size, word_embedding_length); - ConstEigenArrayMap position_embedding_arr(position_embedding->template Data(), hidden_size, position_embedding_length); - ConstEigenArrayMap segment_embedding_arr(segment_embedding->template Data(), hidden_size, segment_embedding_length); - ConstEigenVectorMap gamma_vector(gamma->template Data(), hidden_size); - ConstEigenVectorMap beta_vector(beta->template Data(), hidden_size); - EigenArrayMap output_arr(output->template MutableData(), hidden_size, batch_size * sequence_length); + auto input_ids_data = input_ids->template Data(); + auto segment_ids_data = segment_ids->template Data(); + auto word_embedding_data = word_embedding->template Data(); + auto position_embedding_data = position_embedding->template Data(); + auto segment_embedding_data = segment_embedding->template Data(); + auto gamma_data = gamma->template Data(); + auto beta_data = beta->template Data(); + auto output_data = output->template MutableData(); // Calculate output { - size_t index = 0; - for (int b = 0; b < batch_size; b++) { - for (int s = 0; s < sequence_length; s++) { - int word_col_index = input_ids->template Data()[index]; - if (word_col_index < 0 || word_col_index >= word_embedding_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "word_col_index out of range"); - } - int position_col_index = s; - if (position_col_index >= position_embedding_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "position_col_index out of range"); - } - int segment_col_index = segment_ids->template Data()[index]; - if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "segment_col_index out of range"); - } + std::atomic_bool failed{false}; - output_arr.col(index) = word_embedding_arr.col(word_col_index) + - position_embedding_arr.col(position_col_index) + - segment_embedding_arr.col(segment_col_index); - output_arr.col(index) -= output_arr.col(index).mean(); - output_arr.col(index) /= static_cast(sqrt(output_arr.col(index).pow(2).mean() + 1.0e-13)); - output_arr.col(index) *= gamma_vector.array(); - output_arr.col(index) += beta_vector.array(); - index++; + int n = batch_size * sequence_length; + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), n, [=, &failed](int index) { + int word_col_index = input_ids_data[index]; + if (word_col_index < 0 || word_col_index >= word_embedding_length) { + failed.store(true, std::memory_order_release); + return; } + int position_col_index = index % sequence_length; + if (position_col_index >= position_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + int segment_col_index = segment_ids_data[index]; + if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + + T* y = output_data + index * hidden_size; + const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size; + const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size; + const T* input_segment_embedding = segment_embedding_data + segment_col_index * hidden_size; + + T sum = static_cast(0); + for (int i = 0; i < hidden_size; i++) { + T subtotal = input_word_embedding[i] + input_position_embedding[i] + input_segment_embedding[i]; + y[i] = subtotal; + sum += subtotal; + } + T mean = sum / hidden_size; + sum = 0; + for (int i = 0; i < hidden_size; i++) { + T a = y[i] - mean; + y[i] = a; + sum += a * a; + } + T e = sqrt(sum / hidden_size + static_cast(1.0e-13)); + for (int i = 0; i < hidden_size; i++) { + y[i] = y[i] / e * gamma_data[i] + beta_data[i]; + } + }); + + if (failed.load(std::memory_order_acquire)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input index out of range"); } }