diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index d9dd7a63c0..1d9fe64336 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -6,7 +6,6 @@ #include "core/framework/tensorprotoutils.h" #include "fast_gelu.h" #include "fast_gelu_impl.h" -#include "contrib_ops/cpu/bert/bias_gelu_helper.h" namespace onnxruntime { namespace contrib { @@ -33,15 +32,41 @@ FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel } template -Status FastGelu::ComputeInternal(OpKernelContext* context) const { - ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); +Status FastGelu::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input = ctx->Input(0); - const Tensor* input = context->Input(0); - const Tensor* bias = context->Input(1); - Tensor* output = context->Output(0, input->Shape()); + const auto input_dims = input->Shape().GetDims(); + if (input_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); + } + + size_t num_inputs = OpKernel::Node().InputDefs().size(); + bool has_bias = (num_inputs == 2); + + int input_length = 1; + for (size_t i = 0; i < input_dims.size(); i++) { + input_length *= static_cast(input_dims[i]); + } + + int bias_length = 0; + const Tensor* bias = nullptr; + if (has_bias) { + bias = ctx->Input(1); + const auto bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[input_dims.size() - 1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 dimension 0 should have same length as the last dimension of input 0"); + } + bias_length = static_cast(bias_dims[0]); + } + + Tensor* output = ctx->Output(0, input->Shape()); - int64_t input_length = input->Shape().Size(); - int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; if (!LaunchFastGeluKernel( GetDeviceProp(),