onnxruntime/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Yufeng Li c08d6b42e8
Add tool to support packing mode for BERT model (#15283)
### Description
<!-- Describe your changes. -->
Add a tool to convert a fused BERT-like model to packing mode.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-03-31 08:46:47 -07:00

155 lines
5.7 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cmath>

#include "core/framework/tensor.h"
#include "core/util/math_cpuonly.h"
#include "core/providers/common.h"
#include "core/platform/threadpool.h"
#include "skip_layer_norm.h"
namespace onnxruntime {
namespace contrib {
// Registers SkipLayerNorm<T> as the kernel for the SkipLayerNormalization operator
// in kMSDomain on kCpuExecutionProvider, binding the type constraint "T" to the
// tensor type of T.
// NOTE(review): the literal 1 is presumably the operator's since-version -- confirm
// against the definition of ONNX_OPERATOR_TYPED_KERNEL_EX.
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SkipLayerNormalization, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T>);
// Instantiate the registration once per element type handled by this file.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
// Constructs the kernel and caches the "epsilon" attribute in epsilon_.
// Two conditions are checked with ORT_ENFORCE: the attribute must be readable as a
// float, and its value must be non-negative.
template <typename T>
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info) {
// The attribute read must succeed; no fallback default is applied here.
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
// Compute() adds epsilon_ under a square root, so negative values are rejected up front.
ORT_ENFORCE(epsilon_ >= 0);
}
namespace {

// Checks one of the per-channel parameter tensors (gamma, beta or bias): it must be
// 1-D and its single dimension must equal `hidden_size`. `name` is only used to
// build the error message. Returns Status::OK() when both checks pass, otherwise an
// INVALID_ARGUMENT status.
Status CheckSkipLayerNormParameter(const Tensor& tensor, const char* name, int hidden_size) {
  const auto& dims = tensor.Shape().GetDims();
  if (dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           name, " is expected to have 1 dimension, got ", dims.size());
  }
  if (dims[0] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Last dimension of ", name, " and input does not match");
  }
  return Status::OK();
}

}  // namespace

// Computes output = LayerNorm(input + skip [+ bias]) * gamma [+ beta], normalizing
// over the last dimension of the input.
//
// Inputs (by index): 0 input, 1 skip, 2 gamma, 3 beta (may be null), 4 bias (may be
// null). Output 0 receives the normalized result; output 3, when the context
// provides it, receives the un-normalized sum input + skip [+ bias].
//
// Returns INVALID_ARGUMENT when the input is not 2-D or 3-D, when skip does not have
// the same shape as input, or when gamma/beta/bias is not a 1-D tensor whose length
// equals the last input dimension. Returns Status::OK() otherwise.
template <typename T>
Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
  const Tensor* input = p_ctx->Input<Tensor>(0);
  const Tensor* skip = p_ctx->Input<Tensor>(1);
  const Tensor* gamma = p_ctx->Input<Tensor>(2);
  const Tensor* beta = p_ctx->Input<Tensor>(3);
  const Tensor* bias = p_ctx->Input<Tensor>(4);
  Tensor* output = p_ctx->Output(0, input->Shape());
  // For inferencing, we support one more optional output which is the sum of the
  // input, skip and (when present) bias tensors. It may be null; see the check below.
  Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());

  const auto& input_dims = input->Shape().GetDims();
  size_t input_dims_size = input_dims.size();
  if (input_dims_size != 3 && input_dims_size != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input is expected to have 3 or 2 dimensions, got ", input_dims_size);
  }

  // The normalization axis is always the last one.
  int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);

  if (input->Shape() != skip->Shape()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "skip is expected to have same shape as input");
  }

  // gamma is dereferenced unconditionally; beta and bias are validated only if given.
  Status param_status = CheckSkipLayerNormParameter(*gamma, "gamma", hidden_size);
  if (!param_status.IsOK()) {
    return param_status;
  }
  if (nullptr != beta) {
    param_status = CheckSkipLayerNormParameter(*beta, "beta", hidden_size);
    if (!param_status.IsOK()) {
      return param_status;
    }
  }
  if (nullptr != bias) {
    param_status = CheckSkipLayerNormParameter(*bias, "bias", hidden_size);
    if (!param_status.IsOK()) {
      return param_status;
    }
  }

  // One task per row: the product of every dimension except the last.
  int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);

  const T* input_data = input->Data<T>();
  const T* skip_data = skip->Data<T>();
  const T* gamma_data = gamma->Data<T>();
  const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
  const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();

  T* output_data = output->MutableData<T>();
  T* skip_input_bias_add_output_data =
      skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;

  // NOTE(review): the trailing 0 is the last argument of TryBatchParallelFor;
  // presumably it lets the thread pool choose the batch count -- confirm.
  concurrency::ThreadPool::TryBatchParallelFor(
      p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
      [&](ptrdiff_t task_idx) {
        auto offset = task_idx * hidden_size;

        const T* p_input = input_data + offset;
        const T* p_skip = skip_data + offset;
        T* p_output = output_data + offset;
        T* p_skip_input_bias_add_output_data =
            skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;

        // First pass: form the residual sum, stash it in the output row (and in the
        // optional sum output), and accumulate the first and second moments.
        T sum = 0;
        T sum_of_squares = 0;
        for (int64_t h = 0; h < hidden_size; h++) {
          T value = p_input[h] + p_skip[h];
          if (nullptr != bias_data) {
            value += bias_data[h];
          }
          if (nullptr != p_skip_input_bias_add_output_data) {
            p_skip_input_bias_add_output_data[h] = value;
          }
          p_output[h] = value;
          sum += value;
          sum_of_squares += value * value;
        }

        const T mean = sum / hidden_size;

        // E[x^2] - mean^2 is mathematically non-negative, but rounding can push it
        // slightly below zero for (near-)constant rows. Without the clamp a small or
        // zero epsilon_ would hand a negative value to sqrt and turn the whole row
        // into NaN. A NaN variance fails the comparison and still propagates.
        T variance = sum_of_squares / hidden_size - mean * mean;
        if (variance < 0) {
          variance = 0;
        }
        const T std_dev = std::sqrt(variance + epsilon_);

        // Second pass: normalize in place, then scale by gamma and shift by beta.
        for (int64_t h = 0; h < hidden_size; h++) {
          if (nullptr == beta_data) {
            p_output[h] = (p_output[h] - mean) / std_dev * gamma_data[h];
          } else {
            p_output[h] = (p_output[h] - mean) / std_dev * gamma_data[h] + beta_data[h];
          }
        }
      },
      0);

  return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime