From e2856cfa21a640c0d9f8a04d2ec4d4fc2b2f2293 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Wed, 5 May 2021 18:53:46 -0700 Subject: [PATCH] Undo edit to file --- .../contrib_ops/cpu/bert/embed_layer_norm.cc | 89 +++++++++---------- 1 file changed, 43 insertions(+), 46 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index 5bc974711d..952ad156c1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -10,7 +10,6 @@ namespace onnxruntime { namespace contrib { - // These ops are internal-only, so register outside of onnx #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ @@ -73,53 +72,51 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { std::atomic_bool failed{false}; int n = batch_size * sequence_length; - concurrency::ThreadPool::TryBatchParallelFor( - context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { - int word_col_index = input_ids_data[index]; - if (word_col_index < 0 || word_col_index >= word_embedding_length) { - failed.store(true, std::memory_order_release); - return; - } - int position_col_index = index % sequence_length; - if (position_col_index >= position_embedding_length) { - failed.store(true, std::memory_order_release); - return; - } - int segment_col_index = 0; - if (nullptr != segment_ids_data) { - segment_col_index = segment_ids_data[index]; - if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) { - failed.store(true, std::memory_order_release); - return; - } - } + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { + int word_col_index = input_ids_data[index]; + if (word_col_index < 0 || word_col_index >= word_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + int position_col_index = index % sequence_length; + if (position_col_index >= position_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + int segment_col_index = 0; + if (nullptr != segment_ids_data) { + segment_col_index = segment_ids_data[index]; + if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + } - T* y = output_data + index * hidden_size; - const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size; - const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size; - const T* input_segment_embedding = (nullptr == segment_embedding_data) ? nullptr : segment_embedding_data + segment_col_index * hidden_size; + T* y = output_data + index * hidden_size; + const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size; + const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size; + const T* input_segment_embedding = (nullptr == segment_embedding_data) ? nullptr : segment_embedding_data + segment_col_index * hidden_size; - T sum = static_cast(0); - for (int i = 0; i < hidden_size; i++) { - T subtotal = input_word_embedding[i] + input_position_embedding[i]; - if (nullptr != segment_embedding_data) - subtotal += input_segment_embedding[i]; - y[i] = subtotal; - sum += subtotal; - } - T mean = sum / hidden_size; - sum = 0; - for (int i = 0; i < hidden_size; i++) { - T a = y[i] - mean; - y[i] = a; - sum += a * a; - } - T e = sqrt(sum / hidden_size + static_cast(epsilon_)); - for (int i = 0; i < hidden_size; i++) { - y[i] = y[i] / e * gamma_data[i] + beta_data[i]; - } - }, - 0); + T sum = static_cast(0); + for (int i = 0; i < hidden_size; i++) { + T subtotal = input_word_embedding[i] + input_position_embedding[i]; + if (nullptr != segment_embedding_data) + subtotal += input_segment_embedding[i]; + y[i] = subtotal; + sum += subtotal; + } + T mean = sum / hidden_size; + sum = 0; + for (int i = 0; i < hidden_size; i++) { + T a = y[i] - mean; + y[i] = a; + sum += a * a; + } + T e = sqrt(sum / hidden_size + static_cast(epsilon_)); + for (int i = 0; i < hidden_size; i++) { + y[i] = y[i] / e * gamma_data[i] + beta_data[i]; + } + }, 0); if (failed.load(std::memory_order_acquire)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input index out of range");