Get onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc from ort_training.

This commit is contained in:
Sergii Dymchenko 2020-04-09 17:55:52 -07:00
parent 6bbc80951d
commit c5176087bf

View file

@ -6,7 +6,6 @@
#include "core/framework/tensorprotoutils.h"
#include "fast_gelu.h"
#include "fast_gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
namespace onnxruntime {
namespace contrib {
@ -33,15 +32,41 @@ FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel
}
template <typename T>
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
Status FastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input = ctx->Input<Tensor>(0);
const Tensor* input = context->Input<Tensor>(0);
const Tensor* bias = context->Input<Tensor>(1);
Tensor* output = context->Output(0, input->Shape());
const auto input_dims = input->Shape().GetDims();
if (input_dims.size() < 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
}
size_t num_inputs = OpKernel::Node().InputDefs().size();
bool has_bias = (num_inputs == 2);
int input_length = 1;
for (size_t i = 0; i < input_dims.size(); i++) {
input_length *= static_cast<int>(input_dims[i]);
}
int bias_length = 0;
const Tensor* bias = nullptr;
if (has_bias) {
bias = ctx->Input<Tensor>(1);
const auto bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 1 is expected to have 1 dimensions, got ", bias_dims.size());
}
if (bias_dims[0] != input_dims[input_dims.size() - 1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 1 dimension 0 should have same length as the last dimension of input 0");
}
bias_length = static_cast<int>(bias_dims[0]);
}
Tensor* output = ctx->Output(0, input->Shape());
int64_t input_length = input->Shape().Size();
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType<T>::MappedType CudaT;
if (!LaunchFastGeluKernel<CudaT>(
GetDeviceProp(),