mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Cleanup function definitions of contrib ops (#9265)
* Simplify function definitions * Simplify fast-gelu function definition * Simplify training function op body definitions Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> * Eliminate redundant function Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> * Formatting changes Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> * Minor formatting changes Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> * Add comment Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> * Specify int64 type for constant 1 Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
This commit is contained in:
parent
6e2f66ee9c
commit
0b77c9ca7c
3 changed files with 179 additions and 168 deletions
|
|
@ -758,39 +758,32 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03
|
|||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
auto elem_type = tp->tensor_type().elem_type();
|
||||
|
||||
// Optional input 1 indicates a bias to be added to input 0.
|
||||
auto hasBias = ctx.hasInput(1);
|
||||
|
||||
std::string xb(hasBias ? "X_bias" : "X");
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder
|
||||
.AddOpset("", 13)
|
||||
.Const("a", 0.5, elem_type)
|
||||
.Const("b", 0.797885, elem_type)
|
||||
.Const("c", 0.035677, elem_type)
|
||||
.Const("one", 1.0, elem_type)
|
||||
.Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)")
|
||||
.Add(R"(
|
||||
T1 = Mul (X_bias, X_bias)
|
||||
T2 = Mul (c, T1)
|
||||
T3 = Add (b, T2)
|
||||
T4 = Mul (X_bias, T3)
|
||||
T5 = Tanh (T4)
|
||||
T6 = Add (one, T5)
|
||||
T7 = Mul (X_bias, T6)
|
||||
Y = Mul (a, T7)
|
||||
)");
|
||||
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
// Constants:
|
||||
ONNX_NAMESPACE::Const("a", 0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("b", 0.797885f, elem_type),
|
||||
ONNX_NAMESPACE::Const("c", 0.035677f, elem_type),
|
||||
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
|
||||
// nodes: {outputs, op, inputs, attributes}
|
||||
// Following node to be added only if bias is specified.
|
||||
// {{xb}, "Add", {"X", "bias"}},
|
||||
{{"T1"}, "Mul", {xb, xb}},
|
||||
{{"T2"}, "Mul", {"c", "T1"}},
|
||||
{{"T3"}, "Add", {"b", "T2"}},
|
||||
{{"T4"}, "Mul", {xb, "T3"}},
|
||||
{{"T5"}, "Tanh", {"T4"}},
|
||||
{{"T6"}, "Add", {"one", "T5"}},
|
||||
{{"T7"}, "Mul", {xb, "T6"}},
|
||||
{{"Y"}, "Mul", {"a", "T7"}}};
|
||||
|
||||
if (hasBias)
|
||||
body.insert(body.begin(), {{xb}, "Add", {"X", "bias"}});
|
||||
|
||||
ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization)
|
||||
|
|
@ -2477,58 +2470,56 @@ Example 4:
|
|||
return tp;
|
||||
};
|
||||
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType)U),
|
||||
// The treatment of "axis" is different in "LayerNormalization" and in Reduction operations.
|
||||
// This complicates the function definition, requiring reshaping inputs/outputs.
|
||||
// Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
|
||||
// This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]]
|
||||
// Normalization is applied to the second dimension.
|
||||
// Output Y has same shape as X
|
||||
// Outputs Mean and InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1]
|
||||
{{"XShape"}, "Shape", {"X"}}, // shape of input tensor: 1D tensor
|
||||
{{"Rank"}, "Size", {"XShape"}}, // rank of input tensor: scalar
|
||||
{{"Zero1D"}, "Constant", {}, {{"value", mktensor(0)}}}, // [0] : 1D tensor
|
||||
{{"Axis1D"}, "Constant", {}, {{"value", mktensor(axis)}}}, // [axis] : 1D tensor
|
||||
{{"PrefixShape"}, "Slice", {"XShape", "Zero1D", "Axis1D"}}, // [d[0], ..., d[axis-1]]
|
||||
(axis > 0) ? // number of axes that are reduced =
|
||||
FunctionBodyHelper::NodeDef({"NumReducedAxes"}, "Sub", {"Rank", "Axis1D"}) // [rank - axis]: 1D tensor
|
||||
: FunctionBodyHelper::NodeDef({"NumReducedAxes"}, "Neg", {"Axis1D"}), // [-axis] : 1D tensor
|
||||
{{"SuffixShape"}, "ConstantOfShape", {"NumReducedAxes"}, //
|
||||
{{"value", mktensor(1)}}}, // [1, ..., 1] for reduced axes
|
||||
{{"ReducedShape"}, "Concat", {"PrefixShape", "SuffixShape"}, {{"axis", int64_t(0)}}}, // [d[0], ..., d[axis-1], 1, ..., 1]
|
||||
{{"X2D"}, "Flatten", {"X"}, {{"axis", axis}}},
|
||||
{{"XU"}, "Cast", {"X2D"}, {{"to", U}}},
|
||||
{{"Mean2D"}, "ReduceMean", {"XU"}, {{"axes", std::vector<int64_t>{1}}}},
|
||||
{{"Square"}, "Mul", {"XU", "XU"}},
|
||||
{{"MeanOfSquare"}, "ReduceMean", {"Square"}, {{"axes", std::vector<int64_t>{1}}}},
|
||||
{{"SquareOfMean"}, "Mul", {"Mean2D", "Mean2D"}},
|
||||
{{"Var"}, "Sub", {"MeanOfSquare", "SquareOfMean"}},
|
||||
{{"VarPlusEpsilon"}, "Add", {"Var", "Epsilon"}},
|
||||
{{"StdDev"}, "Sqrt", {"VarPlusEpsilon"}},
|
||||
{{"Deviation"}, "Sub", {"XU", "Mean2D"}},
|
||||
{{"Normalized"}, "Div", {"Deviation", "StdDev"}},
|
||||
{{"NormalizedT"}, "Cast", {"Normalized"}, {{"to", T}}},
|
||||
{{"Scale2D"}, "Flatten", {"Scale"}, {{"axis", int64_t(0)}}},
|
||||
{{"Scaled"}, "Mul", {"NormalizedT", "Scale2D"}}};
|
||||
// The treatment of "axis" is different in "LayerNormalization" and in Reduction operations.
|
||||
// This complicates the function definition, requiring reshaping inputs/outputs.
|
||||
// Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
|
||||
// This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]]
|
||||
// Normalization is applied to the second dimension.
|
||||
// Output Y has same shape as X
|
||||
// Outputs Mean and InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1]
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder
|
||||
.AddOpset("", 13)
|
||||
.Const("Epsilon", epsilon, U)
|
||||
.Add("XShape = Shape (X)") // shape of input tensor: 1D tensor
|
||||
.Add("Rank = Size (XShape)") // rank of input tensor: scalar
|
||||
.Add("Zero1D = Constant()", "value", mktensor(0)) // [0] : 1D tensor
|
||||
.Add("Axis1D = Constant()", "value", mktensor(axis)) // [axis] : 1D tensor
|
||||
.Add("PrefixShape = Slice (XShape, Zero1D, Axis1D)") // [d[0], ..., d[axis-1]]
|
||||
.Add(axis > 0 // number of axes that are reduced =
|
||||
? "NumReducedAxes = Sub (Rank, Axis1D)" // [rank - axis]: 1D tensor
|
||||
: "NumReducedAxes = Neg (Axis1D)") // [-axis] : 1D tensor
|
||||
.Add("SuffixShape = ConstantOfShape (NumReducedAxes)", "value", mktensor(1)) // [1, ..., 1] for reduced axes
|
||||
.Add("ReducedShape = Concat <axis = 0> (PrefixShape, SuffixShape)") // [d[0], ..., d[axis-1], 1, ..., 1]
|
||||
.Add("X2D = Flatten (X)", "axis", axis)
|
||||
.Add("XU = Cast (X2D)", "to", U)
|
||||
.Add("Mean2D = ReduceMean <axes = [1]> (XU)")
|
||||
.Add("Square = Mul (XU, XU)")
|
||||
.Add("MeanOfSquare = ReduceMean <axes = [1]> (Square)")
|
||||
.Add("SquareOfMean = Mul (Mean2D, Mean2D)")
|
||||
.Add("Var = Sub (MeanOfSquare, SquareOfMean)")
|
||||
.Add("VarPlusEpsilon = Add (Var, Epsilon)")
|
||||
.Add("StdDev = Sqrt (VarPlusEpsilon)")
|
||||
.Add("Deviation = Sub (XU, Mean2D)")
|
||||
.Add("Normalized = Div (Deviation, StdDev)")
|
||||
.Add("NormalizedT = Cast (Normalized)", "to", T)
|
||||
.Add("Scale2D = Flatten <axis = 0> (Scale)")
|
||||
.Add("Scaled = Mul (NormalizedT, Scale2D)");
|
||||
if (ctx.hasInput(2)) {
|
||||
body.push_back({{"B2D"}, "Flatten", {"B"}, {{"axis", int64_t(0)}}});
|
||||
body.push_back({{"Biased"}, "Add", {"Scaled", "B2D"}});
|
||||
builder.Add("B2D = Flatten <axis=0> (B)");
|
||||
builder.Add("Biased = Add (Scaled, B2D)");
|
||||
} else {
|
||||
body.push_back({{"Biased"}, "Identity", {"Scaled"}});
|
||||
builder.Add("Biased = Identity (Scaled)");
|
||||
}
|
||||
body.push_back({{"Y"}, "Reshape", {"Biased", "XShape"}});
|
||||
body.push_back({{"InvStdDev2D"}, "Reciprocal", {"StdDev"}});
|
||||
builder.Add("Y = Reshape (Biased, XShape)");
|
||||
builder.Add("InvStdDev2D = Reciprocal (StdDev)");
|
||||
if (ctx.hasOutput(1))
|
||||
body.push_back({{"Mean"}, "Reshape", {"Mean2D", "ReducedShape"}});
|
||||
builder.Add("Mean = Reshape (Mean2D, ReducedShape)");
|
||||
if (ctx.hasOutput(2))
|
||||
body.push_back({{"InvStdDev"}, "Reshape", {"InvStdDev2D", "ReducedShape"}});
|
||||
builder.Add("InvStdDev = Reshape (InvStdDev2D, ReducedShape)");
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SimplifiedLayerNormalization)
|
||||
|
|
@ -2602,26 +2593,24 @@ inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)D
|
|||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
auto elem_type = tp->tensor_type().elem_type();
|
||||
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
// Constants:
|
||||
ONNX_NAMESPACE::Const("Half", 0.5, elem_type),
|
||||
ONNX_NAMESPACE::Const("One", 1.0, elem_type),
|
||||
ONNX_NAMESPACE::Const("C", std::sqrt(0.5), elem_type),
|
||||
// ONNX_NAMESPACE::Const("C", M_SQRT1_2, elem_type),
|
||||
// nodes: {outputs, op, inputs, attributes}
|
||||
{{"CX"}, "Mul", {"C", "X"}},
|
||||
{{"ERFCX"}, "Erf", {"CX"}},
|
||||
{{"ERFCXPlus1"}, "Add", {"ERFCX", "One"}},
|
||||
{{"PhiX"}, "Mul", {"ERFCXPlus1", "Half"}},
|
||||
{{"Y"}, "Mul", {"X", "PhiX"}}};
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder
|
||||
.AddOpset("", 13)
|
||||
.Const("Half", 0.5, elem_type)
|
||||
.Const("One", 1.0, elem_type)
|
||||
.Const("C", std::sqrt(0.5), elem_type)
|
||||
.Add(R"(
|
||||
CX = Mul (C, X)
|
||||
ERFCX = Erf (CX)
|
||||
ERFCXPlus1 = Add (ERFCX, One)
|
||||
PhiX = Mul (ERFCXPlus1, Half)
|
||||
Y = Mul (X, PhiX)
|
||||
)");
|
||||
|
||||
ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
static const char* BiasGelu_ver1_doc =
|
||||
|
|
|
|||
|
|
@ -58,7 +58,30 @@ class FunctionBuilder {
|
|||
|
||||
template <typename T>
|
||||
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) {
|
||||
return Add (node_txt, MakeAttribute(attr_name, attr_value));
|
||||
return Add(node_txt, MakeAttribute(attr_name, attr_value));
|
||||
}
|
||||
|
||||
FunctionBuilder& Const(const std::string& name, double value, int64_t elem_type) {
|
||||
std::string constant_op(name);
|
||||
constant_op += " = Constant()";
|
||||
return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(value, (TensorProto_DataType) elem_type)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FunctionBuilder& Const(const std::string& name, T const_value) {
|
||||
std::string constant_op(name);
|
||||
constant_op += " = Constant()";
|
||||
return Add (constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
|
||||
std::string constant_op(name);
|
||||
constant_op += " = Constant()";
|
||||
auto tensor = ToTensor(values);
|
||||
tensor.add_dims(values.size()); // Treat as 1D tensor.
|
||||
|
||||
return Add (constant_op.c_str(), MakeAttribute("value", tensor));
|
||||
}
|
||||
|
||||
FunctionBuilder& AddOpset(const char* domain, int version) {
|
||||
|
|
|
|||
|
|
@ -616,40 +616,38 @@ void RegisterTrainingOpSchemas() {
|
|||
|
||||
auto* axis_attr = ctx.getAttribute("axis");
|
||||
int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : 1;
|
||||
auto zero1d = ToTensor(std::vector<int64_t>({0}));
|
||||
zero1d.add_dims(1);
|
||||
|
||||
// nodes: {outputs, op, inputs, attributes}
|
||||
|
||||
// First, convert axis specification k to reduction axes [k, k+1, ..., n-1]
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
FunctionBodyHelper::Const<int64_t>("one", 1),
|
||||
FunctionBodyHelper::Const<int64_t>("k", axis),
|
||||
{{"axis_zero"}, "Constant", {}, {{"value", zero1d}}},
|
||||
{{"shape"}, "Shape", {"dY"}},
|
||||
{{"n_as_vector"}, "Shape", {"shape"}},
|
||||
{{"n"}, "Squeeze", {"n_as_vector", "axis_zero"}},
|
||||
};
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder
|
||||
.AddOpset("", 13)
|
||||
.Const("one", int64_t(1))
|
||||
.Const("k", axis)
|
||||
.Const("axis_zero", std::vector<int64_t>({0})) // a 1D tensor constant
|
||||
.Add(R"(
|
||||
shape = Shape (dY)
|
||||
n_as_vector = Shape (shape)
|
||||
n = Squeeze (n_as_vector, axis_zero)
|
||||
)");
|
||||
|
||||
// For negative axis, add n to axis-value k; then use Range(...).
|
||||
if (axis >= 0) {
|
||||
body.push_back({{"reduction_axes"}, "Range", {"k", "n", "one"}});
|
||||
builder.Add("reduction_axes = Range (k, n, one)");
|
||||
} else {
|
||||
body.push_back({{"n_plus_k"}, "Add", {"n", "k"}});
|
||||
body.push_back({{"reduction_axes"}, "Range", {"n_plus_k", "n", "one"}});
|
||||
builder.Add("n_plus_k = Add (n, k)");
|
||||
builder.Add("reduction_axes = Range (n_plus_k, n, one)");
|
||||
}
|
||||
|
||||
// compute dX = Y * ( dY - dot(Y, dY)) = Y * ( dY - ReduceSum(Y * dY))
|
||||
body.push_back({{"a"}, "Mul", {"Y", "dY"}});
|
||||
body.push_back({{"b"}, "ReduceSum", {"a", "reduction_axes"}});
|
||||
body.push_back({{"c"}, "Sub", {"dY", "b"}});
|
||||
body.push_back({{"dX"}, "Mul", {"Y", "c"}});
|
||||
builder.Add(R"(
|
||||
a = Mul (Y ,dY)
|
||||
b = ReduceSum (a ,reduction_axes)
|
||||
c = Sub (dY ,b)
|
||||
dX = Mul (Y ,c)
|
||||
)");
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(LogSoftmaxGrad)
|
||||
|
|
@ -2055,31 +2053,32 @@ Example 4:
|
|||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
auto elem_type = tp->tensor_type().elem_type();
|
||||
double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
ONNX_NAMESPACE::Const("C_Half", 0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("C_One", 1.0f, elem_type),
|
||||
ONNX_NAMESPACE::Const("C_SqrtHalf", float(M_SQRT1_2), elem_type),
|
||||
ONNX_NAMESPACE::Const("C_MinusHalf", -0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("C_alpha", kAlpha, elem_type),
|
||||
{{"ErfArg"}, "Mul", {"X", "C_SqrtHalf"}},
|
||||
{{"ErfTerm"}, "Erf", {"ErfArg"}},
|
||||
{{"PartialSum"}, "Add", {"ErfTerm", "C_One"}},
|
||||
{{"HalfPartialSum"}, "Mul", {"C_Half", "PartialSum"}},
|
||||
{{"AlphaX"}, "Mul", {"X", "C_alpha"}},
|
||||
{{"MinusHalfX"}, "Mul", {"C_MinusHalf", "X"}},
|
||||
{{"ExpArg"}, "Mul", {"MinusHalfX", "X"}},
|
||||
{{"ExpTerm"}, "Exp", {"ExpArg"}},
|
||||
{{"Term3"}, "Mul", {"AlphaX", "ExpTerm"}},
|
||||
{{"FullSum"}, "Add", {"HalfPartialSum", "Term3"}},
|
||||
{{"dX"}, "Mul", {"dY", "FullSum"}}};
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder
|
||||
.AddOpset("", 13)
|
||||
.Const("C_Half", 0.5f, elem_type)
|
||||
.Const("C_One", 1.0f, elem_type)
|
||||
.Const("C_SqrtHalf", float(M_SQRT1_2), elem_type)
|
||||
.Const("C_MinusHalf", -0.5f, elem_type)
|
||||
.Const("C_alpha", kAlpha, elem_type)
|
||||
.Add(R"(
|
||||
ErfArg = Mul (X, C_SqrtHalf)
|
||||
ErfTerm = Erf (ErfArg)
|
||||
PartialSum = Add (ErfTerm, C_One)
|
||||
HalfPartialSum = Mul (C_Half, PartialSum)
|
||||
AlphaX = Mul (X, C_alpha)
|
||||
MinusHalfX = Mul (C_MinusHalf, X)
|
||||
ExpArg = Mul (MinusHalfX, X)
|
||||
ExpTerm = Exp (ExpArg)
|
||||
Term3 = Mul (AlphaX, ExpTerm)
|
||||
FullSum = Add (HalfPartialSum, Term3)
|
||||
dX = Mul (dY, FullSum)
|
||||
)");
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
|
||||
|
|
@ -2675,35 +2674,35 @@ Return true if all elements are true and false otherwise.
|
|||
static constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2;
|
||||
static constexpr double kGamma = 0.044715f;
|
||||
static constexpr double kBeta = kGamma * kAlpha * 3.0f;
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
ONNX_NAMESPACE::Const("half", 0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
|
||||
ONNX_NAMESPACE::Const("alpha", kAlpha, elem_type),
|
||||
ONNX_NAMESPACE::Const("gamma", kGamma, elem_type),
|
||||
ONNX_NAMESPACE::Const("beta", kBeta, elem_type),
|
||||
{{"x_square"}, "Mul", {"X", "X"}},
|
||||
{{"x_cube"}, "Mul", {"X", "x_square"}},
|
||||
{{"gamma_x_cube"}, "Mul", {"gamma", "x_cube"}},
|
||||
{{"sum1"}, "Add", {"X", "gamma_x_cube"}},
|
||||
{{"tanh_arg"}, "Mul", {"alpha", "sum1"}},
|
||||
{{"tanh_val"}, "Tanh", {"tanh_arg"}},
|
||||
{{"tanh_square"}, "Mul", {"tanh_val", "tanh_val"}},
|
||||
{{"sech_square"}, "Sub", {"one", "tanh_square"}},
|
||||
{{"alpha_x"}, "Mul", {"alpha", "X"}},
|
||||
{{"beta_x_cube"}, "Mul", {"beta", "x_cube"}},
|
||||
{{"sum"}, "Add", {"alpha_x", "beta_x_cube"}},
|
||||
{{"term2"}, "Mul", {"sech_square", "sum"}},
|
||||
{{"sum2"}, "Add", {"tanh_val", "term2"}},
|
||||
{{"sum3"}, "Add", {"sum2", "one"}},
|
||||
{{"prod"}, "Mul", {"half", "sum3"}},
|
||||
{{"dX"}, "Mul", {"dY", "prod"}},
|
||||
};
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder
|
||||
.AddOpset("", 13)
|
||||
.Const("half", 0.5f, elem_type)
|
||||
.Const("one", 1.0f, elem_type)
|
||||
.Const("alpha", kAlpha, elem_type)
|
||||
.Const("gamma", kGamma, elem_type)
|
||||
.Const("beta", kBeta, elem_type)
|
||||
.Add(R"ONNX(
|
||||
x_square = Mul (X, X)
|
||||
x_cube = Mul (X, x_square)
|
||||
gamma_x_cube = Mul (gamma, x_cube)
|
||||
sum1 = Add (X, gamma_x_cube)
|
||||
tanh_arg = Mul (alpha, sum1)
|
||||
tanh_val = Tanh (tanh_arg)
|
||||
tanh_square = Mul (tanh_val, tanh_val)
|
||||
sech_square = Sub (one, tanh_square)
|
||||
alpha_x = Mul (alpha, X)
|
||||
beta_x_cube = Mul (beta, x_cube)
|
||||
sum = Add (alpha_x, beta_x_cube)
|
||||
term2 = Mul (sech_square, sum)
|
||||
sum2 = Add (tanh_val, term2)
|
||||
sum3 = Add (sum2, one)
|
||||
prod = Mul (half, sum3)
|
||||
dX = Mul (dY, prod)
|
||||
)ONNX");
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(BiasGeluGrad_dX)
|
||||
|
|
|
|||
Loading…
Reference in a new issue