diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index e337b9bb0d..28f96dbe61 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -25,5 +25,13 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); -} // namespace contrib +ONNX_OPERATOR_KERNEL_EX( + Gelu, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); + +} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index 1e617f5f8b..19f730cbd2 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -6,6 +6,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/util/math_cpuonly.h" +#include namespace onnxruntime { namespace contrib { @@ -28,5 +29,19 @@ class ScaledTanh final : public OpKernel { const float beta_; }; +template +class Gelu : public OpKernel { + public: + Gelu(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* context) const override { + const auto* X = context->Input(0); + Tensor* Y = context->Output(0, X->Shape()); + EIGEN_X_VAR(xm); + EIGEN_Y = xm * 0.5f * ((xm * static_cast(M_SQRT1_2)).erf() + 1.0f); + return Status::OK(); + } +}; + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc index ec454f04bf..e4d573b604 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc @@ -29,8 +29,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CDist); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, CDist); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu); -// This section includes all opkernel declarations for former experimental ops which have now been removed from onnx. +// This section includes all op kernel declarations for former experimental ops which have now been removed from onnx. // To maintain backward compatibility these are added as contrib ops. // Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops // we cannot change the domain now as this will break backward compatibility. @@ -55,7 +56,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Par class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, Conv); @@ -104,6 +104,7 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to main backward compatibility @@ -127,7 +128,8 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { kernel_registry.Register(function_table_entry()); diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index c7c8ea663c..a9d399d431 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -10,10 +10,10 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_ACTIVATION_KERNEL(x, ver, T) \ +#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ x, \ - kOnnxDomain, \ + domain, \ ver, \ T, \ kCudaExecutionProvider, \ @@ -37,23 +37,23 @@ namespace cuda { return Status::OK(); \ } -#define UNARY_ACTIVATION_OP_TYPED(name, ver, T) \ - REGISTER_ACTIVATION_KERNEL(name, ver, T) \ +#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \ + REGISTER_ACTIVATION_KERNEL(name, ver, domain, T) \ UNARY_ACTIVATION_COMPUTE(name, T) -#define UNARY_ACTIVATION_OP_HFD(name, ver) \ - UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \ - UNARY_ACTIVATION_OP_TYPED(name, ver, float) \ - UNARY_ACTIVATION_OP_TYPED(name, ver, double) +#define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \ + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \ + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \ + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double) -UNARY_ACTIVATION_OP_HFD(Affine, 1); -UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1); -UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1); +UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); +UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); +UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain); +UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain); - -REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, MLFloat16) -REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, float) -REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, double) +REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16) +REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float) +REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double) } //namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h index e27c46cc22..763e6cb922 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations.h @@ -66,6 +66,17 @@ class ScaledTanh final : public UnaryElementwise { float beta_; }; +template +class Gelu final : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + MAKE_FUNC_CTX_NULL() +}; + } // namespace cuda -} //namespace contrib +} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 833c9c0da4..59fce9e081 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -36,6 +36,13 @@ struct OP_ScaledTanh : public CtxScaledTanh { } }; +template +struct OP_Gelu : public CtxGelu { + __device__ __inline__ T operator()(const T& a) const { + return a * _Normcdf(a); + } +}; + #define UNARY_ACTIVATION_IMPL(name) \ UNARY_ACTIVATION_IMPL_DECLARATION(name) { \ UnaryElementWiseImpl(input_data, \ diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h index 066f29f81a..95ea6d5af6 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h @@ -11,11 +11,13 @@ namespace cuda { typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine; typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus; typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh; +typedef onnxruntime::cuda::CtxNull CtxGelu; -#define UNARY_CONTRIB_ACTIVATION_OPS() \ - UNARY_ACTIVATION_OP_NAME(ScaledTanh) \ - UNARY_ACTIVATION_OP_NAME(Affine) \ - UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) +#define UNARY_CONTRIB_ACTIVATION_OPS() \ + UNARY_ACTIVATION_OP_NAME(ScaledTanh) \ + UNARY_ACTIVATION_OP_NAME(Affine) \ + UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \ + UNARY_ACTIVATION_OP_NAME(Gelu) #define UNARY_ACTIVATION_IMPL_DECLARATION(name) \ template \ diff --git a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc index e6f318df6d..fa318c8f35 100644 --- a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc @@ -9,6 +9,9 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace contrib { namespace cuda { +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility @@ -34,10 +37,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); - void RegisterCudaContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility BuildKernelCreateInfo, @@ -60,7 +65,7 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index aa609de571..4fa8c76c65 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -9,6 +9,7 @@ #include "core/graph/op.h" #include "onnx/defs/schema.h" #include "onnx/defs/shape_inference.h" +#include "onnx/defs/function.h" #include "core/mlas/inc/mlas.h" #ifdef MICROSOFT_INTERNAL @@ -1723,10 +1724,23 @@ Example 4: RegisterNchwcSchemas(); } + ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("Gelu") + .Input(0, "X", "The input data as Tensor.", "T") + .Output(0, "Y", "The output.", "T") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + #ifdef MICROSOFT_INTERNAL // register internal ops RegisterInternalSchemas(); #endif } } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc new file mode 100644 index 0000000000..6d09891f98 --- /dev/null +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/initializer.h" +#include "core/optimizer/gelu_fusion.h" +#include "core/graph/graph_utils.h" +#include "float.h" +#include + +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::common; +namespace onnxruntime { + +static bool CheckConstantInput(const Graph& graph, const NodeArg& input_arg, float expected_value) { + auto shape = input_arg.Shape(); + auto dim_size = shape->dim_size(); + if (dim_size != 0) { + // only check scalar. + return false; + } + + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + if (tensor_proto == nullptr) { + return false; + } + + auto init_const = std::make_unique(tensor_proto); + const auto data_type = tensor_proto->data_type(); + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + float* val = init_const->data(); + float diff = std::abs(val[0] - static_cast(expected_value)); + if (diff > FLT_EPSILON) { + return false; + } + } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) { + double* val = init_const->data(); + double diff = std::abs(val[0] - static_cast(expected_value)); + if (diff > DBL_EPSILON) { + return false; + } + } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16* val = init_const->data(); + float diff = std::abs(math::halfToFloat(val[0].val) - static_cast(expected_value)); + if (diff > FLT_EPSILON) { + return false; + } + } + + return true; +} + +// Gelu supports limited data types. +static std::vector supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"}; + +static bool IsSupportedDataType(const Node& node) { + for (const auto& input_arg : node.InputDefs()) { + if (std::find(supported_data_types.begin(), supported_data_types.end(), + *(input_arg->Type())) == supported_data_types.end()) { + return false; + } + } + return true; +} + +Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + std::deque removed_nodes; + + for (auto node_index : node_topology_list) { + auto& div = *graph.GetNode(node_index); + ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) || + !graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) || + div.GetOutputEdgesCount() != 1 || + !IsSupportedDataType(div)) { + continue; + } + + // Check second input is sqrt(2) + if (!CheckConstantInput(graph, *(div.MutableInputDefs()[1]), static_cast(M_SQRT2))) { + continue; + } + + const Node& erf_node = *(div.OutputNodesBegin()); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) || + erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() || + erf_node.GetOutputEdgesCount() != 1 || + !IsSupportedDataType(erf_node)) { + continue; + } + + const Node& add_node = *(erf_node.OutputNodesBegin()); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) || + add_node.GetExecutionProviderType() != div.GetExecutionProviderType() || + add_node.GetOutputEdgesCount() != 1 || + !IsSupportedDataType(add_node)) { + continue; + } + + // Check the other input node(e.g. not of type Erf) is 1.0f. + const Node& add_first_input_node = *(add_node.InputNodesBegin()); + int add_const_input_index = 0; + if (add_first_input_node.OpType().compare("Erf") == 0) { + add_const_input_index = 1; + } + const auto& add_const_input_arg = add_node.InputDefs()[add_const_input_index]; + if (!CheckConstantInput(graph, *add_const_input_arg, 1.0f)) { + continue; + } + + const Node& mul_node = *(add_node.OutputNodesBegin()); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) || + mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() || + !IsSupportedDataType(mul_node)) { + continue; + } + + const Node* mul2_node = nullptr; + for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) { + if ((*iter).OpType().compare("Mul") == 0) { + // find the other input node of Mul + mul2_node = &(*iter); + break; + } + } + if (mul2_node == nullptr) { + continue; + } + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*mul2_node, "Mul", {7}) || + mul2_node->GetExecutionProviderType() != div.GetExecutionProviderType() || + mul2_node->GetOutputEdgesCount() != 1 || + !IsSupportedDataType(*mul2_node)) { + continue; + } + + // Check the other input node(e.g. not of type Add) is 0.5f. + int mul_const_input_index = 0; + if (mul2_node->InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) { + mul_const_input_index = 1; + } + + const auto& mul_const_input_arg = mul2_node->InputDefs()[mul_const_input_index]; + if (!CheckConstantInput(graph, *mul_const_input_arg, 0.5f)) { + continue; + } + + const std::vector gelu_input_defs{div.MutableInputDefs()[0]}; + const std::vector gelu_output_defs{const_cast(mul_node.OutputDefs()[0])}; + Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"), + "Gelu", + "fused Gelu subgraphs ", + gelu_input_defs, + gelu_output_defs, {}, kMSDomain); + + // Assign provider to this new node. Provider should be same as the provider for old node. + gelu_node.SetExecutionProviderType(div.GetExecutionProviderType()); + + removed_nodes.push_front(div.Index()); + removed_nodes.push_front(erf_node.Index()); + removed_nodes.push_front(add_node.Index()); + removed_nodes.push_front(mul2_node->Index()); + removed_nodes.push_front(mul_node.Index()); + } + + // Have to remove node in reversed order for now to walk around the issue in RemoveNode + for (onnxruntime::NodeIndex removed_node : removed_nodes) { + graph.RemoveNode(removed_node); + } + + if (!removed_nodes.empty()) { + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gelu_fusion.h b/onnxruntime/core/optimizer/gelu_fusion.h new file mode 100644 index 0000000000..216c7ac41a --- /dev/null +++ b/onnxruntime/core/optimizer/gelu_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class GeluFusion + +Rewrite graph fusing Gelu activation subgraph to a single Gelu node. + +The formula corresponding to Gelu activation subgraph: +x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input. + +*/ +class GeluFusion : public GraphTransformer { + public: + GeluFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept + : GraphTransformer("GeluFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5ad2e5f3b3..b454b71310 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -18,6 +18,7 @@ #include "core/optimizer/shape_to_initializer.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/free_dim_override_transformer.h" +#include "core/optimizer/gelu_fusion.h" #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" @@ -114,6 +115,7 @@ std::vector> GenerateTransformers(TransformerL transformers.emplace_back(onnxruntime::make_unique(l2_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(l2_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(l2_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(l2_execution_providers)); #endif } break; diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc index bc24f1310b..7c248967d2 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.cc +++ b/onnxruntime/core/providers/cpu/activation/activations.cc @@ -45,5 +45,4 @@ Status Tanh::Compute(OpKernelContext* context) const { MlasComputeTanh(X->template Data(), Y->template MutableData(), x_shape.Size()); return Status::OK(); } - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/activation/activations.h b/onnxruntime/core/providers/cpu/activation/activations.h index 0b634c48df..229a1c7cc7 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.h +++ b/onnxruntime/core/providers/cpu/activation/activations.h @@ -188,5 +188,4 @@ class ThresholdedRelu final : public OpKernel { private: const float alpha_; }; - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/activation/activations_impl.h b/onnxruntime/core/providers/cuda/activation/activations_impl.h index 0e3cdf561d..a3a39df63b 100644 --- a/onnxruntime/core/providers/cuda/activation/activations_impl.h +++ b/onnxruntime/core/providers/cuda/activation/activations_impl.h @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #pragma once namespace onnxruntime { @@ -34,16 +33,16 @@ typedef CtxNull CtxSoftsign; typedef CtxNull CtxTanh; typedef CtxAlpha CtxThresholdedRelu; -#define UNARY_ACTIVATION_OPS() \ - UNARY_ACTIVATION_OP_NAME(Elu) \ - UNARY_ACTIVATION_OP_NAME(HardSigmoid) \ - UNARY_ACTIVATION_OP_NAME(LeakyRelu) \ - UNARY_ACTIVATION_OP_NAME(Relu) \ - UNARY_ACTIVATION_OP_NAME(Selu) \ - UNARY_ACTIVATION_OP_NAME(Sigmoid) \ - UNARY_ACTIVATION_OP_NAME(Softplus) \ - UNARY_ACTIVATION_OP_NAME(Softsign) \ - UNARY_ACTIVATION_OP_NAME(Tanh) \ +#define UNARY_ACTIVATION_OPS() \ + UNARY_ACTIVATION_OP_NAME(Elu) \ + UNARY_ACTIVATION_OP_NAME(HardSigmoid) \ + UNARY_ACTIVATION_OP_NAME(LeakyRelu) \ + UNARY_ACTIVATION_OP_NAME(Relu) \ + UNARY_ACTIVATION_OP_NAME(Selu) \ + UNARY_ACTIVATION_OP_NAME(Sigmoid) \ + UNARY_ACTIVATION_OP_NAME(Softplus) \ + UNARY_ACTIVATION_OP_NAME(Softsign) \ + UNARY_ACTIVATION_OP_NAME(Tanh) \ UNARY_ACTIVATION_OP_NAME(ThresholdedRelu) #define UNARY_ACTIVATION_IMPL_DECLARATION(name) \ diff --git a/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh b/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh index e40eb8e967..9217ea7a2f 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh @@ -51,7 +51,7 @@ __global__ void _BinaryElementWiseSimple( const T* lhs_data, const T* rhs_data, T* output_data, - FuncT func, + const FuncT& func, CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); output_data[id] = func(lhs_data[IncL ? id : 0], rhs_data[IncR ? id : 0]); diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 26354287f0..c58113c67d 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -174,6 +174,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Normcdf(T a); + +template <> +__device__ __inline__ float _Normcdf(float a) { return normcdff(a); } + +template <> +__device__ __inline__ double _Normcdf(double a) { return normcdf(a); } + +template <> +__device__ __inline__ half _Normcdf(half a) { return half(normcdff((float)a)); } + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. diff --git a/onnxruntime/test/contrib_ops/activation_op_test.cc b/onnxruntime/test/contrib_ops/activation_op_test.cc index b848dcffa5..ad27aa5fe2 100644 --- a/onnxruntime/test/contrib_ops/activation_op_test.cc +++ b/onnxruntime/test/contrib_ops/activation_op_test.cc @@ -14,8 +14,9 @@ void TestActivationContribOp(const char* szOp, std::vector& input_vals, std::function expected_func, const std::unordered_map attribs = {}, bool is_tensorrt_supported = true, - int opset_version = 7) { - OpTester test(szOp, opset_version); + int opset_version = 7, + const char* domain = kOnnxDomain) { + OpTester test(szOp, opset_version, domain); for (auto attr : attribs) test.AddAttribute(attr.first, attr.second); @@ -45,10 +46,11 @@ std::vector input_values = { TEST(ActivationContribOpTest, ThresholdedRelu_version_1_to_9) { float alpha = 0.1f; - TestActivationContribOp("ThresholdedRelu", - input_values, - [alpha](float x) { return (x >= alpha) ? x : 0; }, - {{"alpha", alpha}}, true, 1); + TestActivationContribOp( + "ThresholdedRelu", + input_values, + [alpha](float x) { return (x >= alpha) ? x : 0; }, + {{"alpha", alpha}}, true, 1); } TEST(ActivationContribOpTest, ScaledTanh) { @@ -77,5 +79,13 @@ TEST(ActivationContribOpTest, ParametricSoftplus) { {{"alpha", alpha}, {"beta", beta}}); } +TEST(ActivationContribOpTest, Gelu) { + TestActivationContribOp( + "Gelu", + input_values, + [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast(M_SQRT1_2))); }, + {}, false, 1, kMSDomain); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 27be9ad800..afe7b2dabb 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -29,6 +29,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/shape_to_initializer.h" +#include "core/optimizer/gelu_fusion.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -537,7 +538,6 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { } TEST(GraphTransformationTests, ReluClipFusion) { - // Clip op schema changed for opset version 11. Until Clip op is updated in ORT hard coding this model to use // older opset. Model model("ReluClipFusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {}); @@ -606,5 +606,26 @@ TEST(GraphTransformationTests, ReluClipFusion) { } } +#ifndef DISABLE_CONTRIB_OPS +TEST(GraphTransformationTests, GeluFusionTest) { + string model_uri = MODEL_FOLDER + "fusion/gelu.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Erf"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["Gelu"] == 1); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index eac1880f7a..2e113527d0 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -12,15 +12,14 @@ void TestUnaryElementwiseOp(const char* szOp, std::vector& input_vals, std::function expected_func, const std::unordered_map attribs = {}, bool is_tensorrt_supported = true, - int opset_version = 7) { - OpTester test(szOp, opset_version); + int opset_version = 7, const char* domain = kOnnxDomain) { + OpTester test(szOp, opset_version, domain); for (auto attr : attribs) test.AddAttribute(attr.first, attr.second); std::vector dims{(int64_t)input_vals.size()}; - std::vector expected_vals; for (const auto& iv : input_vals) expected_vals.push_back(expected_func(iv)); @@ -36,11 +35,11 @@ void TestUnaryElementwiseOp(const char* szOp, std::vector& input_vals, //Disabled because of accuracy issues for MYRIAD FP16 and VAD_M #if defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M) - int relu = strcmp(szOp, "Relu"); - int leaky = strcmp(szOp, "LeakyRelu"); - if(relu == 0 || leaky == 0){ - excluded_providers.insert(kOpenVINOExecutionProvider); - } + int relu = strcmp(szOp, "Relu"); + int leaky = strcmp(szOp, "LeakyRelu"); + if (relu == 0 || leaky == 0) { + excluded_providers.insert(kOpenVINOExecutionProvider); + } #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_providers); @@ -108,10 +107,11 @@ TEST(ActivationOpTest, LeakyRelu) { TEST(ActivationOpTest, ThresholdedRelu) { float alpha = 0.1f; - TestUnaryElementwiseOp("ThresholdedRelu", - input_vals, - [alpha](float x) { return (x >= alpha) ? x : 0; }, - {{"alpha", alpha}}, true, 10); + TestUnaryElementwiseOp( + "ThresholdedRelu", + input_vals, + [alpha](float x) { return (x >= alpha) ? x : 0; }, + {{"alpha", alpha}}, true, 10); } TEST(ActivationOpTest, Selu) { @@ -204,9 +204,10 @@ TEST(ActivationOpTest, Softplus) { } TEST(ActivationOpTest, Softsign) { - TestUnaryElementwiseOp("Softsign", - no_inf_input_vals, - [](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches + TestUnaryElementwiseOp( + "Softsign", + no_inf_input_vals, + [](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches } } // namespace test diff --git a/onnxruntime/test/testdata/transform/fusion/gelu.onnx b/onnxruntime/test/testdata/transform/fusion/gelu.onnx new file mode 100644 index 0000000000..90248632cd Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/gelu.onnx differ