Cleanup function definitions of contrib ops (#9265)

* Simplify function definitions

* Simplify fast-gelu function definition

* Simplify training function op body definitions

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Eliminate redundant function

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Formatting changes

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Minor formatting changes

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Add comment

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Specify int64 type for constant 1

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
This commit is contained in:
G. Ramalingam 2021-10-05 11:38:42 -07:00 committed by GitHub
parent 6e2f66ee9c
commit 0b77c9ca7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 179 additions and 168 deletions

View file

@ -758,39 +758,32 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
auto elem_type = tp->tensor_type().elem_type();
// Optional input 1 indicates a bias to be added to input 0.
auto hasBias = ctx.hasInput(1);
std::string xb(hasBias ? "X_bias" : "X");
FunctionBuilder builder(functionProto);
builder
.AddOpset("", 13)
.Const("a", 0.5, elem_type)
.Const("b", 0.797885, elem_type)
.Const("c", 0.035677, elem_type)
.Const("one", 1.0, elem_type)
.Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)")
.Add(R"(
T1 = Mul (X_bias, X_bias)
T2 = Mul (c, T1)
T3 = Add (b, T2)
T4 = Mul (X_bias, T3)
T5 = Tanh (T4)
T6 = Add (one, T5)
T7 = Mul (X_bias, T6)
Y = Mul (a, T7)
)");
std::vector<FunctionBodyHelper::NodeDef> body{
// Constants:
ONNX_NAMESPACE::Const("a", 0.5f, elem_type),
ONNX_NAMESPACE::Const("b", 0.797885f, elem_type),
ONNX_NAMESPACE::Const("c", 0.035677f, elem_type),
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
// nodes: {outputs, op, inputs, attributes}
// Following node to be added only if bias is specified.
// {{xb}, "Add", {"X", "bias"}},
{{"T1"}, "Mul", {xb, xb}},
{{"T2"}, "Mul", {"c", "T1"}},
{{"T3"}, "Add", {"b", "T2"}},
{{"T4"}, "Mul", {xb, "T3"}},
{{"T5"}, "Tanh", {"T4"}},
{{"T6"}, "Add", {"one", "T5"}},
{{"T7"}, "Mul", {xb, "T6"}},
{{"Y"}, "Mul", {"a", "T7"}}};
if (hasBias)
body.insert(body.begin(), {{xb}, "Add", {"X", "bias"}});
ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
schema.BuildFunction(functionProto);
return true;
});
ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization)
@ -2477,58 +2470,56 @@ Example 4:
return tp;
};
std::vector<FunctionBodyHelper::NodeDef> body{
ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType)U),
// The treatment of "axis" is different in "LayerNormalization" and in Reduction operations.
// This complicates the function definition, requiring reshaping inputs/outputs.
// Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
// This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]]
// Normalization is applied to the second dimension.
// Output Y has same shape as X
// Outputs Mean and InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1]
{{"XShape"}, "Shape", {"X"}}, // shape of input tensor: 1D tensor
{{"Rank"}, "Size", {"XShape"}}, // rank of input tensor: scalar
{{"Zero1D"}, "Constant", {}, {{"value", mktensor(0)}}}, // [0] : 1D tensor
{{"Axis1D"}, "Constant", {}, {{"value", mktensor(axis)}}}, // [axis] : 1D tensor
{{"PrefixShape"}, "Slice", {"XShape", "Zero1D", "Axis1D"}}, // [d[0], ..., d[axis-1]]
(axis > 0) ? // number of axes that are reduced =
FunctionBodyHelper::NodeDef({"NumReducedAxes"}, "Sub", {"Rank", "Axis1D"}) // [rank - axis]: 1D tensor
: FunctionBodyHelper::NodeDef({"NumReducedAxes"}, "Neg", {"Axis1D"}), // [-axis] : 1D tensor
{{"SuffixShape"}, "ConstantOfShape", {"NumReducedAxes"}, //
{{"value", mktensor(1)}}}, // [1, ..., 1] for reduced axes
{{"ReducedShape"}, "Concat", {"PrefixShape", "SuffixShape"}, {{"axis", int64_t(0)}}}, // [d[0], ..., d[axis-1], 1, ..., 1]
{{"X2D"}, "Flatten", {"X"}, {{"axis", axis}}},
{{"XU"}, "Cast", {"X2D"}, {{"to", U}}},
{{"Mean2D"}, "ReduceMean", {"XU"}, {{"axes", std::vector<int64_t>{1}}}},
{{"Square"}, "Mul", {"XU", "XU"}},
{{"MeanOfSquare"}, "ReduceMean", {"Square"}, {{"axes", std::vector<int64_t>{1}}}},
{{"SquareOfMean"}, "Mul", {"Mean2D", "Mean2D"}},
{{"Var"}, "Sub", {"MeanOfSquare", "SquareOfMean"}},
{{"VarPlusEpsilon"}, "Add", {"Var", "Epsilon"}},
{{"StdDev"}, "Sqrt", {"VarPlusEpsilon"}},
{{"Deviation"}, "Sub", {"XU", "Mean2D"}},
{{"Normalized"}, "Div", {"Deviation", "StdDev"}},
{{"NormalizedT"}, "Cast", {"Normalized"}, {{"to", T}}},
{{"Scale2D"}, "Flatten", {"Scale"}, {{"axis", int64_t(0)}}},
{{"Scaled"}, "Mul", {"NormalizedT", "Scale2D"}}};
// The treatment of "axis" is different in "LayerNormalization" and in Reduction operations.
// This complicates the function definition, requiring reshaping inputs/outputs.
// Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]
// This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]]
// Normalization is applied to the second dimension.
// Output Y has same shape as X
// Outputs Mean and InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1]
FunctionBuilder builder(functionProto);
builder
.AddOpset("", 13)
.Const("Epsilon", epsilon, U)
.Add("XShape = Shape (X)") // shape of input tensor: 1D tensor
.Add("Rank = Size (XShape)") // rank of input tensor: scalar
.Add("Zero1D = Constant()", "value", mktensor(0)) // [0] : 1D tensor
.Add("Axis1D = Constant()", "value", mktensor(axis)) // [axis] : 1D tensor
.Add("PrefixShape = Slice (XShape, Zero1D, Axis1D)") // [d[0], ..., d[axis-1]]
.Add(axis > 0 // number of axes that are reduced =
? "NumReducedAxes = Sub (Rank, Axis1D)" // [rank - axis]: 1D tensor
: "NumReducedAxes = Neg (Axis1D)") // [-axis] : 1D tensor
.Add("SuffixShape = ConstantOfShape (NumReducedAxes)", "value", mktensor(1)) // [1, ..., 1] for reduced axes
.Add("ReducedShape = Concat <axis = 0> (PrefixShape, SuffixShape)") // [d[0], ..., d[axis-1], 1, ..., 1]
.Add("X2D = Flatten (X)", "axis", axis)
.Add("XU = Cast (X2D)", "to", U)
.Add("Mean2D = ReduceMean <axes = [1]> (XU)")
.Add("Square = Mul (XU, XU)")
.Add("MeanOfSquare = ReduceMean <axes = [1]> (Square)")
.Add("SquareOfMean = Mul (Mean2D, Mean2D)")
.Add("Var = Sub (MeanOfSquare, SquareOfMean)")
.Add("VarPlusEpsilon = Add (Var, Epsilon)")
.Add("StdDev = Sqrt (VarPlusEpsilon)")
.Add("Deviation = Sub (XU, Mean2D)")
.Add("Normalized = Div (Deviation, StdDev)")
.Add("NormalizedT = Cast (Normalized)", "to", T)
.Add("Scale2D = Flatten <axis = 0> (Scale)")
.Add("Scaled = Mul (NormalizedT, Scale2D)");
if (ctx.hasInput(2)) {
body.push_back({{"B2D"}, "Flatten", {"B"}, {{"axis", int64_t(0)}}});
body.push_back({{"Biased"}, "Add", {"Scaled", "B2D"}});
builder.Add("B2D = Flatten <axis=0> (B)");
builder.Add("Biased = Add (Scaled, B2D)");
} else {
body.push_back({{"Biased"}, "Identity", {"Scaled"}});
builder.Add("Biased = Identity (Scaled)");
}
body.push_back({{"Y"}, "Reshape", {"Biased", "XShape"}});
body.push_back({{"InvStdDev2D"}, "Reciprocal", {"StdDev"}});
builder.Add("Y = Reshape (Biased, XShape)");
builder.Add("InvStdDev2D = Reciprocal (StdDev)");
if (ctx.hasOutput(1))
body.push_back({{"Mean"}, "Reshape", {"Mean2D", "ReducedShape"}});
builder.Add("Mean = Reshape (Mean2D, ReducedShape)");
if (ctx.hasOutput(2))
body.push_back({{"InvStdDev"}, "Reshape", {"InvStdDev2D", "ReducedShape"}});
builder.Add("InvStdDev = Reshape (InvStdDev2D, ReducedShape)");
OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
schema.BuildFunction(functionProto);
return true;
});
ONNX_CONTRIB_OPERATOR_SCHEMA(SimplifiedLayerNormalization)
@ -2602,26 +2593,24 @@ inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)D
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
auto elem_type = tp->tensor_type().elem_type();
std::vector<FunctionBodyHelper::NodeDef> body{
// Constants:
ONNX_NAMESPACE::Const("Half", 0.5, elem_type),
ONNX_NAMESPACE::Const("One", 1.0, elem_type),
ONNX_NAMESPACE::Const("C", std::sqrt(0.5), elem_type),
// ONNX_NAMESPACE::Const("C", M_SQRT1_2, elem_type),
// nodes: {outputs, op, inputs, attributes}
{{"CX"}, "Mul", {"C", "X"}},
{{"ERFCX"}, "Erf", {"CX"}},
{{"ERFCXPlus1"}, "Add", {"ERFCX", "One"}},
{{"PhiX"}, "Mul", {"ERFCXPlus1", "Half"}},
{{"Y"}, "Mul", {"X", "PhiX"}}};
FunctionBuilder builder(functionProto);
builder
.AddOpset("", 13)
.Const("Half", 0.5, elem_type)
.Const("One", 1.0, elem_type)
.Const("C", std::sqrt(0.5), elem_type)
.Add(R"(
CX = Mul (C, X)
ERFCX = Erf (CX)
ERFCXPlus1 = Add (ERFCX, One)
PhiX = Mul (ERFCXPlus1, Half)
Y = Mul (X, PhiX)
)");
ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
schema.BuildFunction(functionProto);
return true;
});
static const char* BiasGelu_ver1_doc =

View file

@ -58,7 +58,30 @@ class FunctionBuilder {
template <typename T>
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) {
return Add (node_txt, MakeAttribute(attr_name, attr_value));
return Add(node_txt, MakeAttribute(attr_name, attr_value));
}
FunctionBuilder& Const(const std::string& name, double value, int64_t elem_type) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(value, (TensorProto_DataType) elem_type)));
}
template <typename T>
FunctionBuilder& Const(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add (constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
}
template <typename T>
FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(values);
tensor.add_dims(values.size()); // Treat as 1D tensor.
return Add (constant_op.c_str(), MakeAttribute("value", tensor));
}
FunctionBuilder& AddOpset(const char* domain, int version) {

View file

@ -616,40 +616,38 @@ void RegisterTrainingOpSchemas() {
auto* axis_attr = ctx.getAttribute("axis");
int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : 1;
auto zero1d = ToTensor(std::vector<int64_t>({0}));
zero1d.add_dims(1);
// nodes: {outputs, op, inputs, attributes}
// First, convert axis specification k to reduction axes [k, k+1, ..., n-1]
std::vector<FunctionBodyHelper::NodeDef> body{
FunctionBodyHelper::Const<int64_t>("one", 1),
FunctionBodyHelper::Const<int64_t>("k", axis),
{{"axis_zero"}, "Constant", {}, {{"value", zero1d}}},
{{"shape"}, "Shape", {"dY"}},
{{"n_as_vector"}, "Shape", {"shape"}},
{{"n"}, "Squeeze", {"n_as_vector", "axis_zero"}},
};
FunctionBuilder builder(functionProto);
builder
.AddOpset("", 13)
.Const("one", int64_t(1))
.Const("k", axis)
.Const("axis_zero", std::vector<int64_t>({0})) // a 1D tensor constant
.Add(R"(
shape = Shape (dY)
n_as_vector = Shape (shape)
n = Squeeze (n_as_vector, axis_zero)
)");
// For negative axis, add n to axis-value k; then use Range(...).
if (axis >= 0) {
body.push_back({{"reduction_axes"}, "Range", {"k", "n", "one"}});
builder.Add("reduction_axes = Range (k, n, one)");
} else {
body.push_back({{"n_plus_k"}, "Add", {"n", "k"}});
body.push_back({{"reduction_axes"}, "Range", {"n_plus_k", "n", "one"}});
builder.Add("n_plus_k = Add (n, k)");
builder.Add("reduction_axes = Range (n_plus_k, n, one)");
}
// compute dX = Y * ( dY - dot(Y, dY)) = Y * ( dY - ReduceSum(Y * dY))
body.push_back({{"a"}, "Mul", {"Y", "dY"}});
body.push_back({{"b"}, "ReduceSum", {"a", "reduction_axes"}});
body.push_back({{"c"}, "Sub", {"dY", "b"}});
body.push_back({{"dX"}, "Mul", {"Y", "c"}});
builder.Add(R"(
a = Mul (Y ,dY)
b = ReduceSum (a ,reduction_axes)
c = Sub (dY ,b)
dX = Mul (Y ,c)
)");
OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
schema.BuildFunction(functionProto);
return true;
});
ONNX_CONTRIB_OPERATOR_SCHEMA(LogSoftmaxGrad)
@ -2055,31 +2053,32 @@ Example 4:
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
auto elem_type = tp->tensor_type().elem_type();
double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
std::vector<FunctionBodyHelper::NodeDef> body{
ONNX_NAMESPACE::Const("C_Half", 0.5f, elem_type),
ONNX_NAMESPACE::Const("C_One", 1.0f, elem_type),
ONNX_NAMESPACE::Const("C_SqrtHalf", float(M_SQRT1_2), elem_type),
ONNX_NAMESPACE::Const("C_MinusHalf", -0.5f, elem_type),
ONNX_NAMESPACE::Const("C_alpha", kAlpha, elem_type),
{{"ErfArg"}, "Mul", {"X", "C_SqrtHalf"}},
{{"ErfTerm"}, "Erf", {"ErfArg"}},
{{"PartialSum"}, "Add", {"ErfTerm", "C_One"}},
{{"HalfPartialSum"}, "Mul", {"C_Half", "PartialSum"}},
{{"AlphaX"}, "Mul", {"X", "C_alpha"}},
{{"MinusHalfX"}, "Mul", {"C_MinusHalf", "X"}},
{{"ExpArg"}, "Mul", {"MinusHalfX", "X"}},
{{"ExpTerm"}, "Exp", {"ExpArg"}},
{{"Term3"}, "Mul", {"AlphaX", "ExpTerm"}},
{{"FullSum"}, "Add", {"HalfPartialSum", "Term3"}},
{{"dX"}, "Mul", {"dY", "FullSum"}}};
FunctionBuilder builder(functionProto);
builder
.AddOpset("", 13)
.Const("C_Half", 0.5f, elem_type)
.Const("C_One", 1.0f, elem_type)
.Const("C_SqrtHalf", float(M_SQRT1_2), elem_type)
.Const("C_MinusHalf", -0.5f, elem_type)
.Const("C_alpha", kAlpha, elem_type)
.Add(R"(
ErfArg = Mul (X, C_SqrtHalf)
ErfTerm = Erf (ErfArg)
PartialSum = Add (ErfTerm, C_One)
HalfPartialSum = Mul (C_Half, PartialSum)
AlphaX = Mul (X, C_alpha)
MinusHalfX = Mul (C_MinusHalf, X)
ExpArg = Mul (MinusHalfX, X)
ExpTerm = Exp (ExpArg)
Term3 = Mul (AlphaX, ExpTerm)
FullSum = Add (HalfPartialSum, Term3)
dX = Mul (dY, FullSum)
)");
OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
schema.BuildFunction(functionProto);
return true;
});
ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
@ -2675,35 +2674,35 @@ Return true if all elements are true and false otherwise.
static constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2;
static constexpr double kGamma = 0.044715f;
static constexpr double kBeta = kGamma * kAlpha * 3.0f;
std::vector<FunctionBodyHelper::NodeDef> body{
ONNX_NAMESPACE::Const("half", 0.5f, elem_type),
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
ONNX_NAMESPACE::Const("alpha", kAlpha, elem_type),
ONNX_NAMESPACE::Const("gamma", kGamma, elem_type),
ONNX_NAMESPACE::Const("beta", kBeta, elem_type),
{{"x_square"}, "Mul", {"X", "X"}},
{{"x_cube"}, "Mul", {"X", "x_square"}},
{{"gamma_x_cube"}, "Mul", {"gamma", "x_cube"}},
{{"sum1"}, "Add", {"X", "gamma_x_cube"}},
{{"tanh_arg"}, "Mul", {"alpha", "sum1"}},
{{"tanh_val"}, "Tanh", {"tanh_arg"}},
{{"tanh_square"}, "Mul", {"tanh_val", "tanh_val"}},
{{"sech_square"}, "Sub", {"one", "tanh_square"}},
{{"alpha_x"}, "Mul", {"alpha", "X"}},
{{"beta_x_cube"}, "Mul", {"beta", "x_cube"}},
{{"sum"}, "Add", {"alpha_x", "beta_x_cube"}},
{{"term2"}, "Mul", {"sech_square", "sum"}},
{{"sum2"}, "Add", {"tanh_val", "term2"}},
{{"sum3"}, "Add", {"sum2", "one"}},
{{"prod"}, "Mul", {"half", "sum3"}},
{{"dX"}, "Mul", {"dY", "prod"}},
};
FunctionBuilder builder(functionProto);
builder
.AddOpset("", 13)
.Const("half", 0.5f, elem_type)
.Const("one", 1.0f, elem_type)
.Const("alpha", kAlpha, elem_type)
.Const("gamma", kGamma, elem_type)
.Const("beta", kBeta, elem_type)
.Add(R"ONNX(
x_square = Mul (X, X)
x_cube = Mul (X, x_square)
gamma_x_cube = Mul (gamma, x_cube)
sum1 = Add (X, gamma_x_cube)
tanh_arg = Mul (alpha, sum1)
tanh_val = Tanh (tanh_arg)
tanh_square = Mul (tanh_val, tanh_val)
sech_square = Sub (one, tanh_square)
alpha_x = Mul (alpha, X)
beta_x_cube = Mul (beta, x_cube)
sum = Add (alpha_x, beta_x_cube)
term2 = Mul (sech_square, sum)
sum2 = Add (tanh_val, term2)
sum3 = Add (sum2, one)
prod = Mul (half, sum3)
dX = Mul (dY, prod)
)ONNX");
OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
schema.BuildFunction(functionProto);
return true;
});
ONNX_CONTRIB_OPERATOR_SCHEMA(BiasGeluGrad_dX)