Undo edit to file

This commit is contained in:
Ryan Hill 2021-05-05 18:53:46 -07:00
parent 605508a071
commit e2856cfa21

View file

@ -10,7 +10,6 @@
namespace onnxruntime {
namespace contrib {
// These ops are internal-only, so register outside of onnx
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
@ -73,53 +72,51 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
std::atomic_bool failed{false};
int n = batch_size * sequence_length;
concurrency::ThreadPool::TryBatchParallelFor(
context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) {
int word_col_index = input_ids_data[index];
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
int position_col_index = index % sequence_length;
if (position_col_index >= position_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
int segment_col_index = 0;
if (nullptr != segment_ids_data) {
segment_col_index = segment_ids_data[index];
if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
}
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) {
int word_col_index = input_ids_data[index];
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
int position_col_index = index % sequence_length;
if (position_col_index >= position_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
int segment_col_index = 0;
if (nullptr != segment_ids_data) {
segment_col_index = segment_ids_data[index];
if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) {
failed.store(true, std::memory_order_release);
return;
}
}
T* y = output_data + index * hidden_size;
const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size;
const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size;
const T* input_segment_embedding = (nullptr == segment_embedding_data) ? nullptr : segment_embedding_data + segment_col_index * hidden_size;
T* y = output_data + index * hidden_size;
const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size;
const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size;
const T* input_segment_embedding = (nullptr == segment_embedding_data) ? nullptr : segment_embedding_data + segment_col_index * hidden_size;
T sum = static_cast<T>(0);
for (int i = 0; i < hidden_size; i++) {
T subtotal = input_word_embedding[i] + input_position_embedding[i];
if (nullptr != segment_embedding_data)
subtotal += input_segment_embedding[i];
y[i] = subtotal;
sum += subtotal;
}
T mean = sum / hidden_size;
sum = 0;
for (int i = 0; i < hidden_size; i++) {
T a = y[i] - mean;
y[i] = a;
sum += a * a;
}
T e = sqrt(sum / hidden_size + static_cast<T>(epsilon_));
for (int i = 0; i < hidden_size; i++) {
y[i] = y[i] / e * gamma_data[i] + beta_data[i];
}
},
0);
T sum = static_cast<T>(0);
for (int i = 0; i < hidden_size; i++) {
T subtotal = input_word_embedding[i] + input_position_embedding[i];
if (nullptr != segment_embedding_data)
subtotal += input_segment_embedding[i];
y[i] = subtotal;
sum += subtotal;
}
T mean = sum / hidden_size;
sum = 0;
for (int i = 0; i < hidden_size; i++) {
T a = y[i] - mean;
y[i] = a;
sum += a * a;
}
T e = sqrt(sum / hidden_size + static_cast<T>(epsilon_));
for (int i = 0; i < hidden_size; i++) {
y[i] = y[i] / e * gamma_data[i] + beta_data[i];
}
}, 0);
if (failed.load(std::memory_order_acquire)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input index out of range");