mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Get onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc from ort_training.
This commit is contained in:
parent
6bbc80951d
commit
c5176087bf
1 changed files with 33 additions and 8 deletions
|
|
@ -6,7 +6,6 @@
|
|||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "fast_gelu.h"
|
||||
#include "fast_gelu_impl.h"
|
||||
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -33,15 +32,41 @@ FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
|
||||
Status FastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* input = ctx->Input<Tensor>(0);
|
||||
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* bias = context->Input<Tensor>(1);
|
||||
Tensor* output = context->Output(0, input->Shape());
|
||||
const auto input_dims = input->Shape().GetDims();
|
||||
if (input_dims.size() < 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
|
||||
}
|
||||
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
bool has_bias = (num_inputs == 2);
|
||||
|
||||
int input_length = 1;
|
||||
for (size_t i = 0; i < input_dims.size(); i++) {
|
||||
input_length *= static_cast<int>(input_dims[i]);
|
||||
}
|
||||
|
||||
int bias_length = 0;
|
||||
const Tensor* bias = nullptr;
|
||||
if (has_bias) {
|
||||
bias = ctx->Input<Tensor>(1);
|
||||
const auto bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 1 is expected to have 1 dimensions, got ", bias_dims.size());
|
||||
}
|
||||
if (bias_dims[0] != input_dims[input_dims.size() - 1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 1 dimension 0 should have same length as the last dimension of input 0");
|
||||
}
|
||||
bias_length = static_cast<int>(bias_dims[0]);
|
||||
}
|
||||
|
||||
Tensor* output = ctx->Output(0, input->Shape());
|
||||
|
||||
int64_t input_length = input->Shape().Size();
|
||||
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
if (!LaunchFastGeluKernel<CudaT>(
|
||||
GetDeviceProp(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue