From 0b77c9ca7ce34e160222c2c165ea4e7daf8d56a0 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 5 Oct 2021 11:38:42 -0700 Subject: [PATCH] Cleanup function definitions of contrib ops (#9265) * Simplify function definitions * Simplify fast-gelu function definition * Simplify training function op body definitions Signed-off-by: Ganesan Ramalingam * Eliminate redundant function Signed-off-by: Ganesan Ramalingam * Formatting changes Signed-off-by: Ganesan Ramalingam * Minor formatting changes Signed-off-by: Ganesan Ramalingam * Add comment Signed-off-by: Ganesan Ramalingam * Specify int64 type for constant 1 Signed-off-by: Ganesan Ramalingam --- .../core/graph/contrib_ops/contrib_defs.cc | 173 ++++++++---------- .../graph/contrib_ops/onnx_function_util.h | 25 ++- .../core/graph/training_op_defs.cc | 149 ++++++++------- 3 files changed, 179 insertions(+), 168 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 429c9b18ad..dfaf089979 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -758,39 +758,32 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03 auto* tp = ctx.getInputType(0); if ((tp == nullptr) || (!tp->has_tensor_type())) return false; - auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type(); + auto elem_type = tp->tensor_type().elem_type(); // Optional input 1 indicates a bias to be added to input 0. auto hasBias = ctx.hasInput(1); - std::string xb(hasBias ? "X_bias" : "X"); + FunctionBuilder builder(functionProto); + builder + .AddOpset("", 13) + .Const("a", 0.5, elem_type) + .Const("b", 0.797885, elem_type) + .Const("c", 0.035677, elem_type) + .Const("one", 1.0, elem_type) + .Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)") + .Add(R"( + T1 = Mul (X_bias, X_bias) + T2 = Mul (c, T1) + T3 = Add (b, T2) + T4 = Mul (X_bias, T3) + T5 = Tanh (T4) + T6 = Add (one, T5) + T7 = Mul (X_bias, T6) + Y = Mul (a, T7) + )"); - std::vector body{ - // Constants: - ONNX_NAMESPACE::Const("a", 0.5f, elem_type), - ONNX_NAMESPACE::Const("b", 0.797885f, elem_type), - ONNX_NAMESPACE::Const("c", 0.035677f, elem_type), - ONNX_NAMESPACE::Const("one", 1.0f, elem_type), - // nodes: {outputs, op, inputs, attributes} - // Following node to be added only if bias is specified. - // {{xb}, "Add", {"X", "bias"}}, - {{"T1"}, "Mul", {xb, xb}}, - {{"T2"}, "Mul", {"c", "T1"}}, - {{"T3"}, "Add", {"b", "T2"}}, - {{"T4"}, "Mul", {xb, "T3"}}, - {{"T5"}, "Tanh", {"T4"}}, - {{"T6"}, "Add", {"one", "T5"}}, - {{"T7"}, "Mul", {xb, "T6"}}, - {{"Y"}, "Mul", {"a", "T7"}}}; - - if (hasBias) - body.insert(body.begin(), {{xb}, "Add", {"X", "bias"}}); - - ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13; - onnx_opset_13.set_domain(""); - onnx_opset_13.set_version(13); - - return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + schema.BuildFunction(functionProto); + return true; }); ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization) @@ -2477,58 +2470,56 @@ Example 4: return tp; }; - std::vector body{ - ONNX_NAMESPACE::Const("Epsilon", epsilon, (ONNX_NAMESPACE::TensorProto_DataType)U), - // The treatment of "axis" is different in "LayerNormalization" and in Reduction operations. - // This complicates the function definition, requiring reshaping inputs/outputs. - // Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]] - // This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]] - // Normalization is applied to the second dimension. - // Output Y has same shape as X - // Outputs Mean and InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1] - {{"XShape"}, "Shape", {"X"}}, // shape of input tensor: 1D tensor - {{"Rank"}, "Size", {"XShape"}}, // rank of input tensor: scalar - {{"Zero1D"}, "Constant", {}, {{"value", mktensor(0)}}}, // [0] : 1D tensor - {{"Axis1D"}, "Constant", {}, {{"value", mktensor(axis)}}}, // [axis] : 1D tensor - {{"PrefixShape"}, "Slice", {"XShape", "Zero1D", "Axis1D"}}, // [d[0], ..., d[axis-1]] - (axis > 0) ? // number of axes that are reduced = - FunctionBodyHelper::NodeDef({"NumReducedAxes"}, "Sub", {"Rank", "Axis1D"}) // [rank - axis]: 1D tensor - : FunctionBodyHelper::NodeDef({"NumReducedAxes"}, "Neg", {"Axis1D"}), // [-axis] : 1D tensor - {{"SuffixShape"}, "ConstantOfShape", {"NumReducedAxes"}, // - {{"value", mktensor(1)}}}, // [1, ..., 1] for reduced axes - {{"ReducedShape"}, "Concat", {"PrefixShape", "SuffixShape"}, {{"axis", int64_t(0)}}}, // [d[0], ..., d[axis-1], 1, ..., 1] - {{"X2D"}, "Flatten", {"X"}, {{"axis", axis}}}, - {{"XU"}, "Cast", {"X2D"}, {{"to", U}}}, - {{"Mean2D"}, "ReduceMean", {"XU"}, {{"axes", std::vector{1}}}}, - {{"Square"}, "Mul", {"XU", "XU"}}, - {{"MeanOfSquare"}, "ReduceMean", {"Square"}, {{"axes", std::vector{1}}}}, - {{"SquareOfMean"}, "Mul", {"Mean2D", "Mean2D"}}, - {{"Var"}, "Sub", {"MeanOfSquare", "SquareOfMean"}}, - {{"VarPlusEpsilon"}, "Add", {"Var", "Epsilon"}}, - {{"StdDev"}, "Sqrt", {"VarPlusEpsilon"}}, - {{"Deviation"}, "Sub", {"XU", "Mean2D"}}, - {{"Normalized"}, "Div", {"Deviation", "StdDev"}}, - {{"NormalizedT"}, "Cast", {"Normalized"}, {{"to", T}}}, - {{"Scale2D"}, "Flatten", {"Scale"}, {{"axis", int64_t(0)}}}, - {{"Scaled"}, "Mul", {"NormalizedT", "Scale2D"}}}; + // The treatment of "axis" is different in "LayerNormalization" and in Reduction operations. + // This complicates the function definition, requiring reshaping inputs/outputs. + // Input X shape: [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]] + // This is treated as a 2D shape [d[0] * ... * d[axis-1], d[axis] * ... * d[rank-1]] + // Normalization is applied to the second dimension. + // Output Y has same shape as X + // Outputs Mean and InvStdDev have shape: [d[0], ..., d[axis-1], 1, ..., 1] + FunctionBuilder builder(functionProto); + builder + .AddOpset("", 13) + .Const("Epsilon", epsilon, U) + .Add("XShape = Shape (X)") // shape of input tensor: 1D tensor + .Add("Rank = Size (XShape)") // rank of input tensor: scalar + .Add("Zero1D = Constant()", "value", mktensor(0)) // [0] : 1D tensor + .Add("Axis1D = Constant()", "value", mktensor(axis)) // [axis] : 1D tensor + .Add("PrefixShape = Slice (XShape, Zero1D, Axis1D)") // [d[0], ..., d[axis-1]] + .Add(axis > 0 // number of axes that are reduced = + ? "NumReducedAxes = Sub (Rank, Axis1D)" // [rank - axis]: 1D tensor + : "NumReducedAxes = Neg (Axis1D)") // [-axis] : 1D tensor + .Add("SuffixShape = ConstantOfShape (NumReducedAxes)", "value", mktensor(1)) // [1, ..., 1] for reduced axes + .Add("ReducedShape = Concat (PrefixShape, SuffixShape)") // [d[0], ..., d[axis-1], 1, ..., 1] + .Add("X2D = Flatten (X)", "axis", axis) + .Add("XU = Cast (X2D)", "to", U) + .Add("Mean2D = ReduceMean (XU)") + .Add("Square = Mul (XU, XU)") + .Add("MeanOfSquare = ReduceMean (Square)") + .Add("SquareOfMean = Mul (Mean2D, Mean2D)") + .Add("Var = Sub (MeanOfSquare, SquareOfMean)") + .Add("VarPlusEpsilon = Add (Var, Epsilon)") + .Add("StdDev = Sqrt (VarPlusEpsilon)") + .Add("Deviation = Sub (XU, Mean2D)") + .Add("Normalized = Div (Deviation, StdDev)") + .Add("NormalizedT = Cast (Normalized)", "to", T) + .Add("Scale2D = Flatten (Scale)") + .Add("Scaled = Mul (NormalizedT, Scale2D)"); if (ctx.hasInput(2)) { - body.push_back({{"B2D"}, "Flatten", {"B"}, {{"axis", int64_t(0)}}}); - body.push_back({{"Biased"}, "Add", {"Scaled", "B2D"}}); + builder.Add("B2D = Flatten (B)"); + builder.Add("Biased = Add (Scaled, B2D)"); } else { - body.push_back({{"Biased"}, "Identity", {"Scaled"}}); + builder.Add("Biased = Identity (Scaled)"); } - body.push_back({{"Y"}, "Reshape", {"Biased", "XShape"}}); - body.push_back({{"InvStdDev2D"}, "Reciprocal", {"StdDev"}}); + builder.Add("Y = Reshape (Biased, XShape)"); + builder.Add("InvStdDev2D = Reciprocal (StdDev)"); if (ctx.hasOutput(1)) - body.push_back({{"Mean"}, "Reshape", {"Mean2D", "ReducedShape"}}); + builder.Add("Mean = Reshape (Mean2D, ReducedShape)"); if (ctx.hasOutput(2)) - body.push_back({{"InvStdDev"}, "Reshape", {"InvStdDev2D", "ReducedShape"}}); + builder.Add("InvStdDev = Reshape (InvStdDev2D, ReducedShape)"); - OperatorSetIdProto onnx_opset_13; - onnx_opset_13.set_domain(""); - onnx_opset_13.set_version(13); - - return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + schema.BuildFunction(functionProto); + return true; }); ONNX_CONTRIB_OPERATOR_SCHEMA(SimplifiedLayerNormalization) @@ -2602,26 +2593,24 @@ inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)D auto* tp = ctx.getInputType(0); if ((tp == nullptr) || (!tp->has_tensor_type())) return false; - auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type(); + auto elem_type = tp->tensor_type().elem_type(); - std::vector body{ - // Constants: - ONNX_NAMESPACE::Const("Half", 0.5, elem_type), - ONNX_NAMESPACE::Const("One", 1.0, elem_type), - ONNX_NAMESPACE::Const("C", std::sqrt(0.5), elem_type), - // ONNX_NAMESPACE::Const("C", M_SQRT1_2, elem_type), - // nodes: {outputs, op, inputs, attributes} - {{"CX"}, "Mul", {"C", "X"}}, - {{"ERFCX"}, "Erf", {"CX"}}, - {{"ERFCXPlus1"}, "Add", {"ERFCX", "One"}}, - {{"PhiX"}, "Mul", {"ERFCXPlus1", "Half"}}, - {{"Y"}, "Mul", {"X", "PhiX"}}}; + FunctionBuilder builder(functionProto); + builder + .AddOpset("", 13) + .Const("Half", 0.5, elem_type) + .Const("One", 1.0, elem_type) + .Const("C", std::sqrt(0.5), elem_type) + .Add(R"( + CX = Mul (C, X) + ERFCX = Erf (CX) + ERFCXPlus1 = Add (ERFCX, One) + PhiX = Mul (ERFCXPlus1, Half) + Y = Mul (X, PhiX) + )"); - ONNX_NAMESPACE::OperatorSetIdProto onnx_opset_13; - onnx_opset_13.set_domain(""); - onnx_opset_13.set_version(13); - - return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + schema.BuildFunction(functionProto); + return true; }); static const char* BiasGelu_ver1_doc = diff --git a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h index 61694eafe5..3ecd8a6a6e 100644 --- a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h +++ b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h @@ -58,7 +58,30 @@ class FunctionBuilder { template FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) { - return Add (node_txt, MakeAttribute(attr_name, attr_value)); + return Add(node_txt, MakeAttribute(attr_name, attr_value)); + } + + FunctionBuilder& Const(const std::string& name, double value, int64_t elem_type) { + std::string constant_op(name); + constant_op += " = Constant()"; + return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(value, (TensorProto_DataType) elem_type))); + } + + template + FunctionBuilder& Const(const std::string& name, T const_value) { + std::string constant_op(name); + constant_op += " = Constant()"; + return Add (constant_op.c_str(), MakeAttribute("value", ToTensor(const_value))); + } + + template + FunctionBuilder& Const(const std::string& name, const std::vector& values) { + std::string constant_op(name); + constant_op += " = Constant()"; + auto tensor = ToTensor(values); + tensor.add_dims(values.size()); // Treat as 1D tensor. + + return Add (constant_op.c_str(), MakeAttribute("value", tensor)); } FunctionBuilder& AddOpset(const char* domain, int version) { diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 6f2c818f4d..5770b3b269 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -616,40 +616,38 @@ void RegisterTrainingOpSchemas() { auto* axis_attr = ctx.getAttribute("axis"); int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : 1; - auto zero1d = ToTensor(std::vector({0})); - zero1d.add_dims(1); - - // nodes: {outputs, op, inputs, attributes} // First, convert axis specification k to reduction axes [k, k+1, ..., n-1] - std::vector body{ - FunctionBodyHelper::Const("one", 1), - FunctionBodyHelper::Const("k", axis), - {{"axis_zero"}, "Constant", {}, {{"value", zero1d}}}, - {{"shape"}, "Shape", {"dY"}}, - {{"n_as_vector"}, "Shape", {"shape"}}, - {{"n"}, "Squeeze", {"n_as_vector", "axis_zero"}}, - }; + FunctionBuilder builder(functionProto); + builder + .AddOpset("", 13) + .Const("one", int64_t(1)) + .Const("k", axis) + .Const("axis_zero", std::vector({0})) // a 1D tensor constant + .Add(R"( + shape = Shape (dY) + n_as_vector = Shape (shape) + n = Squeeze (n_as_vector, axis_zero) + )"); // For negative axis, add n to axis-value k; then use Range(...). if (axis >= 0) { - body.push_back({{"reduction_axes"}, "Range", {"k", "n", "one"}}); + builder.Add("reduction_axes = Range (k, n, one)"); } else { - body.push_back({{"n_plus_k"}, "Add", {"n", "k"}}); - body.push_back({{"reduction_axes"}, "Range", {"n_plus_k", "n", "one"}}); + builder.Add("n_plus_k = Add (n, k)"); + builder.Add("reduction_axes = Range (n_plus_k, n, one)"); } // compute dX = Y * ( dY - dot(Y, dY)) = Y * ( dY - ReduceSum(Y * dY)) - body.push_back({{"a"}, "Mul", {"Y", "dY"}}); - body.push_back({{"b"}, "ReduceSum", {"a", "reduction_axes"}}); - body.push_back({{"c"}, "Sub", {"dY", "b"}}); - body.push_back({{"dX"}, "Mul", {"Y", "c"}}); + builder.Add(R"( + a = Mul (Y ,dY) + b = ReduceSum (a ,reduction_axes) + c = Sub (dY ,b) + dX = Mul (Y ,c) + )"); - OperatorSetIdProto onnx_opset_13; - onnx_opset_13.set_domain(""); - onnx_opset_13.set_version(13); - - return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + schema.BuildFunction(functionProto); + return true; }); ONNX_CONTRIB_OPERATOR_SCHEMA(LogSoftmaxGrad) @@ -2055,31 +2053,32 @@ Example 4: auto* tp = ctx.getInputType(0); if ((tp == nullptr) || (!tp->has_tensor_type())) return false; - auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type(); + auto elem_type = tp->tensor_type().elem_type(); double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5; - std::vector body{ - ONNX_NAMESPACE::Const("C_Half", 0.5f, elem_type), - ONNX_NAMESPACE::Const("C_One", 1.0f, elem_type), - ONNX_NAMESPACE::Const("C_SqrtHalf", float(M_SQRT1_2), elem_type), - ONNX_NAMESPACE::Const("C_MinusHalf", -0.5f, elem_type), - ONNX_NAMESPACE::Const("C_alpha", kAlpha, elem_type), - {{"ErfArg"}, "Mul", {"X", "C_SqrtHalf"}}, - {{"ErfTerm"}, "Erf", {"ErfArg"}}, - {{"PartialSum"}, "Add", {"ErfTerm", "C_One"}}, - {{"HalfPartialSum"}, "Mul", {"C_Half", "PartialSum"}}, - {{"AlphaX"}, "Mul", {"X", "C_alpha"}}, - {{"MinusHalfX"}, "Mul", {"C_MinusHalf", "X"}}, - {{"ExpArg"}, "Mul", {"MinusHalfX", "X"}}, - {{"ExpTerm"}, "Exp", {"ExpArg"}}, - {{"Term3"}, "Mul", {"AlphaX", "ExpTerm"}}, - {{"FullSum"}, "Add", {"HalfPartialSum", "Term3"}}, - {{"dX"}, "Mul", {"dY", "FullSum"}}}; + FunctionBuilder builder(functionProto); + builder + .AddOpset("", 13) + .Const("C_Half", 0.5f, elem_type) + .Const("C_One", 1.0f, elem_type) + .Const("C_SqrtHalf", float(M_SQRT1_2), elem_type) + .Const("C_MinusHalf", -0.5f, elem_type) + .Const("C_alpha", kAlpha, elem_type) + .Add(R"( + ErfArg = Mul (X, C_SqrtHalf) + ErfTerm = Erf (ErfArg) + PartialSum = Add (ErfTerm, C_One) + HalfPartialSum = Mul (C_Half, PartialSum) + AlphaX = Mul (X, C_alpha) + MinusHalfX = Mul (C_MinusHalf, X) + ExpArg = Mul (MinusHalfX, X) + ExpTerm = Exp (ExpArg) + Term3 = Mul (AlphaX, ExpTerm) + FullSum = Add (HalfPartialSum, Term3) + dX = Mul (dY, FullSum) + )"); - OperatorSetIdProto onnx_opset_13; - onnx_opset_13.set_domain(""); - onnx_opset_13.set_version(13); - - return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + schema.BuildFunction(functionProto); + return true; }); ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad) @@ -2675,35 +2674,35 @@ Return true if all elements are true and false otherwise. static constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2; static constexpr double kGamma = 0.044715f; static constexpr double kBeta = kGamma * kAlpha * 3.0f; - std::vector body{ - ONNX_NAMESPACE::Const("half", 0.5f, elem_type), - ONNX_NAMESPACE::Const("one", 1.0f, elem_type), - ONNX_NAMESPACE::Const("alpha", kAlpha, elem_type), - ONNX_NAMESPACE::Const("gamma", kGamma, elem_type), - ONNX_NAMESPACE::Const("beta", kBeta, elem_type), - {{"x_square"}, "Mul", {"X", "X"}}, - {{"x_cube"}, "Mul", {"X", "x_square"}}, - {{"gamma_x_cube"}, "Mul", {"gamma", "x_cube"}}, - {{"sum1"}, "Add", {"X", "gamma_x_cube"}}, - {{"tanh_arg"}, "Mul", {"alpha", "sum1"}}, - {{"tanh_val"}, "Tanh", {"tanh_arg"}}, - {{"tanh_square"}, "Mul", {"tanh_val", "tanh_val"}}, - {{"sech_square"}, "Sub", {"one", "tanh_square"}}, - {{"alpha_x"}, "Mul", {"alpha", "X"}}, - {{"beta_x_cube"}, "Mul", {"beta", "x_cube"}}, - {{"sum"}, "Add", {"alpha_x", "beta_x_cube"}}, - {{"term2"}, "Mul", {"sech_square", "sum"}}, - {{"sum2"}, "Add", {"tanh_val", "term2"}}, - {{"sum3"}, "Add", {"sum2", "one"}}, - {{"prod"}, "Mul", {"half", "sum3"}}, - {{"dX"}, "Mul", {"dY", "prod"}}, - }; + FunctionBuilder builder(functionProto); + builder + .AddOpset("", 13) + .Const("half", 0.5f, elem_type) + .Const("one", 1.0f, elem_type) + .Const("alpha", kAlpha, elem_type) + .Const("gamma", kGamma, elem_type) + .Const("beta", kBeta, elem_type) + .Add(R"ONNX( + x_square = Mul (X, X) + x_cube = Mul (X, x_square) + gamma_x_cube = Mul (gamma, x_cube) + sum1 = Add (X, gamma_x_cube) + tanh_arg = Mul (alpha, sum1) + tanh_val = Tanh (tanh_arg) + tanh_square = Mul (tanh_val, tanh_val) + sech_square = Sub (one, tanh_square) + alpha_x = Mul (alpha, X) + beta_x_cube = Mul (beta, x_cube) + sum = Add (alpha_x, beta_x_cube) + term2 = Mul (sech_square, sum) + sum2 = Add (tanh_val, term2) + sum3 = Add (sum2, one) + prod = Mul (half, sum3) + dX = Mul (dY, prod) + )ONNX"); - OperatorSetIdProto onnx_opset_13; - onnx_opset_13.set_domain(""); - onnx_opset_13.set_version(13); - - return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + schema.BuildFunction(functionProto); + return true; }); ONNX_CONTRIB_OPERATOR_SCHEMA(BiasGeluGrad_dX)