mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Add function body for Gelu and FastGelu (#7496)
* LayerNorm function body v1 * LayerNorm function body * layernorm function test * Minor fixes * Fix signed unsigned comparison * Move contrib ops test * Handle optional output parameters * Add test case for optional outputs * Handle float16 random generation * Add function body to Gelu and FastGelu * Add FastGelu test * Fix comments * Include cmath
This commit is contained in:
parent
7079dfb93d
commit
b0a3b501fe
2 changed files with 120 additions and 4 deletions
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
|
||||
|
|
@ -548,7 +550,46 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03
|
|||
.Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
|
||||
.Output(0, "Y", "output tensor", "T")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
// Builds the function body for this op from the node list below, which computes
//   fastgelu(x) = 0.5 * x * (1 + tanh(x * (0.797885 + 0.035677 * x * x)))
// where x is input "X", or "X" + "bias" when the optional input 1 is present.
// Returns false when the type of input 0 is unavailable or is not a tensor type;
// otherwise returns the result of BuildFunctionProto.
|
||||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
// Element type of input 0; passed to every Const(...) below.
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
|
||||
// Optional input 1 indicates a bias to be added to input 0.
|
||||
auto hasBias = ctx.hasInput(1);
|
||||
|
||||
// Name of the value the activation is applied to: "X_bias" (output of the extra Add) or plain "X".
std::string xb(hasBias ? "X_bias" : "X");
|
||||
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
// Constants:
|
||||
ONNX_NAMESPACE::Const("a", 0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("b", 0.797885f, elem_type),
|
||||
ONNX_NAMESPACE::Const("c", 0.035677f, elem_type),
|
||||
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
|
||||
// nodes: {outputs, op, inputs, attributes}
|
||||
// Following node to be added only if bias is specified.
|
||||
// {{xb}, "Add", {"X", "bias"}},
|
||||
{{"T1"}, "Mul", {xb, xb}},  // T1 = xb * xb
|
||||
{{"T2"}, "Mul", {"c", "T1"}},  // T2 = c * xb^2
|
||||
{{"T3"}, "Add", {"b", "T2"}},  // T3 = b + c * xb^2
|
||||
{{"T4"}, "Mul", {xb, "T3"}},  // T4 = xb * (b + c * xb^2)
|
||||
{{"T5"}, "Tanh", {"T4"}},  // T5 = tanh(T4)
|
||||
{{"T6"}, "Add", {"one", "T5"}},  // T6 = 1 + tanh(T4)
|
||||
{{"T7"}, "Mul", {xb, "T6"}},  // T7 = xb * (1 + tanh(T4))
|
||||
{{"Y"}, "Mul", {"a", "T7"}}};  // Y = a * T7
|
||||
|
||||
// With a bias, prepend the Add node so that "X_bias" is produced before the nodes that consume it.
if (hasBias)
|
||||
body.insert(body.begin(), {{xb}, "Add", {"X", "bias"}});
|
||||
|
||||
// Opset entry handed to BuildFunctionProto: domain "" at version 13.
// NOTE(review): "" is presumably the default ONNX operator domain - confirm.
ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
@ -2172,7 +2213,7 @@ Example 4:
|
|||
};
|
||||
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType) U),
|
||||
ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType)U),
|
||||
// The treatment of "axis" is different in "LayerNormalization" and in Reduction operations.
|
||||
// This complicates the function definition, requiring reshaping inputs/outputs.
|
||||
// Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
|
||||
|
|
@ -2291,7 +2332,33 @@ inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)D
|
|||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
// Builds the function body for this op from the node list below.
// gelu(x) = x * Phi(x) = x * 1/2(1+erf(x/sqrt(2)))
// Returns false when the type of input 0 is unavailable or is not a tensor type;
// otherwise returns the result of BuildFunctionProto.
|
||||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
// Element type of input 0; passed to every Const(...) below.
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
// Constants:
|
||||
ONNX_NAMESPACE::Const("Half", 0.5, elem_type),
|
||||
ONNX_NAMESPACE::Const("One", 1.0, elem_type),
|
||||
// C = sqrt(0.5) = 1/sqrt(2), the factor applied to x inside erf.
ONNX_NAMESPACE::Const("C", std::sqrt(0.5), elem_type),
|
||||
// ONNX_NAMESPACE::Const("C", M_SQRT1_2, elem_type),
|
||||
// nodes: {outputs, op, inputs, attributes}
|
||||
{{"CX"}, "Mul", {"C", "X"}},  // CX = x / sqrt(2)
|
||||
{{"ERFCX"}, "Erf", {"CX"}},  // ERFCX = erf(CX)
|
||||
{{"ERFCXPlus1"}, "Add", {"ERFCX", "One"}},  // ERFCXPlus1 = 1 + erf(CX)
|
||||
{{"PhiX"}, "Mul", {"ERFCXPlus1", "Half"}},  // PhiX = 1/2 * (1 + erf(CX))
|
||||
{{"Y"}, "Mul", {"X", "PhiX"}}};  // Y = x * PhiX
|
||||
|
||||
// Opset entry handed to BuildFunctionProto: domain "" at version 13.
// NOTE(review): "" is presumably the default ONNX operator domain - confirm.
ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
});
|
||||
|
||||
static const char* BiasGelu_ver1_doc =
|
||||
R"DOC(Bias Gelu.
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) {
|
|||
testCase.RunTest();
|
||||
else
|
||||
testCase.CreateModel(true);
|
||||
} // namespace test
|
||||
}
|
||||
|
||||
TEST_F(ContribFunExpansionTest, LayerNorm) {
|
||||
// Test expand-and-run
|
||||
|
|
@ -59,5 +59,54 @@ TEST_F(ContribFunExpansionTest, LayerNorm_OptionalOutputs) {
|
|||
CheckLayerNorm<float, float, true>(true, false);
|
||||
}
|
||||
|
||||
// Sets up a FunctionTestCase for the "Gelu" op in kMSDomain with a single input "x"
// of element type T and shape {8, 16} and a single output "y", then calls CreateModel(true).
template <typename T>
|
||||
void CheckGelu() {
|
||||
FunctionTestCase testCase("Gelu", kMSDomain);
|
||||
std::vector<int64_t> shape{8, 16};
|
||||
|
||||
testCase.AddInput<T>("x", shape);
|
||||
testCase.AddOutput("y");
|
||||
|
||||
// Only check the expanded graph; it cannot be run because no implementation of Erf is available yet.
|
||||
testCase.CreateModel(true);
|
||||
}
|
||||
|
||||
// Calls CheckGelu once per element type: float, double, BFloat16 and MLFloat16.
TEST_F(ContribFunExpansionTest, Gelu) {
|
||||
CheckGelu<float>();
|
||||
CheckGelu<double>();
|
||||
CheckGelu<BFloat16>();
|
||||
CheckGelu<MLFloat16>();
|
||||
}
|
||||
|
||||
// Sets up a FunctionTestCase for the "FastGelu" op in kMSDomain.
//   T        - element type of the inputs.
//   RunTest  - when true, testCase.RunTest() is called; otherwise only testCase.CreateModel(true).
//   withBias - when true, the "bias" input of shape {16} is added after "x" (shape {8, 16}).
template <typename T, bool RunTest = true>
|
||||
void CheckFastGelu(bool withBias = true) {
|
||||
FunctionTestCase testCase("FastGelu", kMSDomain);
|
||||
std::vector<int64_t> shape{8, 16};
|
||||
std::vector<int64_t> bias_shape{16};
|
||||
|
||||
testCase.AddInput<T, RunTest>("x", shape);
|
||||
// The bias input is optional in the FastGelu schema, so it is only added on request.
if (withBias) {
|
||||
testCase.AddInput<T, RunTest>("bias", bias_shape);
|
||||
}
|
||||
testCase.AddOutput("y");
|
||||
|
||||
if (RunTest)
|
||||
testCase.RunTest();
|
||||
else
|
||||
testCase.CreateModel(true);
|
||||
}
|
||||
|
||||
// FastGelu with the bias input present. float uses the default RunTest = true;
// BFloat16 and MLFloat16 pass RunTest = false, so CheckFastGelu only calls CreateModel(true) for them.
TEST_F(ContribFunExpansionTest, FastGeluWithBias) {
|
||||
CheckFastGelu<float>(true);
|
||||
CheckFastGelu<BFloat16, false>(true);
|
||||
CheckFastGelu<MLFloat16, false>(true);
|
||||
}
|
||||
|
||||
// FastGelu without the bias input. float uses the default RunTest = true;
// BFloat16 and MLFloat16 pass RunTest = false, so CheckFastGelu only calls CreateModel(true) for them.
TEST_F(ContribFunExpansionTest, FastGeluWithoutBias) {
|
||||
CheckFastGelu<float>(false);
|
||||
CheckFastGelu<BFloat16, false>(false);
|
||||
CheckFastGelu<MLFloat16, false>(false);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue