NOTE(review): recovered from a whitespace-collapsed copy of this patch in which every `<...>` span was stripped.
Restored by inference: `#include <cmath>` (the patch introduces std::sqrt) and the element type of
`std::vector<FunctionBodyHelper::NodeDef>` (the node list passed to BuildFunctionProto). Indentation is
reconstructed. Confirm all of this against the upstream commit before applying.

diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index b1cdd50bf3..d95b95055c 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <cmath>
+
 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/constants.h"
 #include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
@@ -548,7 +550,46 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03
       .Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
       .Output(0, "Y", "output tensor", "T")
       .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
-      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
+      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
+      .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
+        // fastgelu(x) = 0.5 * x * (1 + tanh(x * (0.797885 + 0.035677 * x * x))), with x = X + bias when bias is present
+        auto* tp = ctx.getInputType(0);
+        if ((tp == nullptr) || (!tp->has_tensor_type()))
+          return false;
+        auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
+
+        // Optional input 1 indicates a bias to be added to input 0.
+        auto hasBias = ctx.hasInput(1);
+
+        std::string xb(hasBias ? "X_bias" : "X");
+
+        std::vector<FunctionBodyHelper::NodeDef> body{
+            // Constants:
+            ONNX_NAMESPACE::Const("a", 0.5f, elem_type),
+            ONNX_NAMESPACE::Const("b", 0.797885f, elem_type),
+            ONNX_NAMESPACE::Const("c", 0.035677f, elem_type),
+            ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
+            // nodes: {outputs, op, inputs, attributes}
+            // Following node to be added only if bias is specified.
+            // {{xb}, "Add", {"X", "bias"}},
+            {{"T1"}, "Mul", {xb, xb}},
+            {{"T2"}, "Mul", {"c", "T1"}},
+            {{"T3"}, "Add", {"b", "T2"}},
+            {{"T4"}, "Mul", {xb, "T3"}},
+            {{"T5"}, "Tanh", {"T4"}},
+            {{"T6"}, "Add", {"one", "T5"}},
+            {{"T7"}, "Mul", {xb, "T6"}},
+            {{"Y"}, "Mul", {"a", "T7"}}};
+
+        if (hasBias)
+          body.insert(body.begin(), {{xb}, "Add", {"X", "bias"}});
+
+        ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
+        onnx_opset_13.set_domain("");
+        onnx_opset_13.set_version(13);
+
+        return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
+      });
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization)
       .SetDomain(kMSDomain)
@@ -2172,7 +2213,7 @@ Example 4:
           };
 
           std::vector<FunctionBodyHelper::NodeDef> body{
-              ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType) U),
+              ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType)U),
               // The treatment of "axis" is different in "LayerNormalization" and in Reduction operations.
               // This complicates the function definition, requiring reshaping inputs/outputs.
               // Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
@@ -2291,7 +2332,33 @@ inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)D
           "T",
           {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
           "Constrain input and output types to float tensors.")
-      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
+      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
+      .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
+        // gelu(x) = x * Phi(x) = x * 1/2(1+erf(x/sqrt(2)))
+        auto* tp = ctx.getInputType(0);
+        if ((tp == nullptr) || (!tp->has_tensor_type()))
+          return false;
+        auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
+
+        std::vector<FunctionBodyHelper::NodeDef> body{
+            // Constants:
+            ONNX_NAMESPACE::Const("Half", 0.5, elem_type),
+            ONNX_NAMESPACE::Const("One", 1.0, elem_type),
+            ONNX_NAMESPACE::Const("C", std::sqrt(0.5), elem_type),
+            // ONNX_NAMESPACE::Const("C", M_SQRT1_2, elem_type),
+            // nodes: {outputs, op, inputs, attributes}
+            {{"CX"}, "Mul", {"C", "X"}},
+            {{"ERFCX"}, "Erf", {"CX"}},
+            {{"ERFCXPlus1"}, "Add", {"ERFCX", "One"}},
+            {{"PhiX"}, "Mul", {"ERFCXPlus1", "Half"}},
+            {{"Y"}, "Mul", {"X", "PhiX"}}};
+
+        ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
+        onnx_opset_13.set_domain("");
+        onnx_opset_13.set_version(13);
+
+        return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
+      });
 
   static const char* BiasGelu_ver1_doc =
       R"DOC(Bias Gelu.
NOTE(review): recovered from a whitespace-collapsed copy of this patch in which every `<...>` span was stripped.
The template parameter lists, `std::vector<int64_t>`, `AddInput<T>` and the template arguments on the added
Check* calls are reconstructions: one call per element type allowed by the Gelu / FastGelu type constraints.
The order of the types and the RunTest flags given to CheckFastGelu are assumptions. The context line
`CheckLayerNorm(true, false);` lost its template arguments as well and is left as found. Indentation is
reconstructed. Confirm all of this against the upstream commit before applying.

diff --git a/onnxruntime/test/contrib_ops/function_ops_test.cc b/onnxruntime/test/contrib_ops/function_ops_test.cc
index 3f202b0fee..bae5a9c97b 100644
--- a/onnxruntime/test/contrib_ops/function_ops_test.cc
+++ b/onnxruntime/test/contrib_ops/function_ops_test.cc
@@ -43,7 +43,7 @@ void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) {
     testCase.RunTest();
   else
     testCase.CreateModel(true);
-} // namespace test
+}
 
 TEST_F(ContribFunExpansionTest, LayerNorm) {
   // Test expand-and-run
@@ -59,5 +59,54 @@ TEST_F(ContribFunExpansionTest, LayerNorm_OptionalOutputs) {
   CheckLayerNorm(true, false);
 }
 
+template <typename T>
+void CheckGelu() {
+  FunctionTestCase testCase("Gelu", kMSDomain);
+  std::vector<int64_t> shape{8, 16};
+
+  testCase.AddInput<T>("x", shape);
+  testCase.AddOutput("y");
+
+  // Only check the expanded graph; it cannot be run because no implementation of Erf is available yet.
+  testCase.CreateModel(true);
+}
+
+TEST_F(ContribFunExpansionTest, Gelu) {
+  CheckGelu<float>();
+  CheckGelu<double>();
+  CheckGelu<BFloat16>();
+  CheckGelu<MLFloat16>();
+}
+
+template <typename T, bool RunTest>
+void CheckFastGelu(bool withBias = true) {
+  FunctionTestCase testCase("FastGelu", kMSDomain);
+  std::vector<int64_t> shape{8, 16};
+  std::vector<int64_t> bias_shape{16};
+
+  testCase.AddInput<T>("x", shape);
+  if (withBias) {
+    testCase.AddInput<T>("bias", bias_shape);
+  }
+  testCase.AddOutput("y");
+
+  if (RunTest)
+    testCase.RunTest();
+  else
+    testCase.CreateModel(true);
+}
+
+TEST_F(ContribFunExpansionTest, FastGeluWithBias) {
+  CheckFastGelu<float, true>(true);
+  CheckFastGelu<BFloat16, false>(true);
+  CheckFastGelu<MLFloat16, false>(true);
+}
+
+TEST_F(ContribFunExpansionTest, FastGeluWithoutBias) {
+  CheckFastGelu<float, true>(false);
+  CheckFastGelu<BFloat16, false>(false);
+  CheckFastGelu<MLFloat16, false>(false);
+}
+
 } // namespace test
 } // namespace onnxruntime
\ No newline at end of file