diff --git a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h
index babe1f6537..b57441bc6a 100644
--- a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h
+++ b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h
@@ -15,6 +15,14 @@ namespace ONNX_NAMESPACE {
 // For floating-value constants of different precision:
 TensorProto ToTensor(double value, TensorProto_DataType elem_type);
 
+// Utility function to construct a constant of given type/precision.
+// Returns a NodeDef for a "Constant" op that takes no inputs, produces the single
+// output `name`, and carries ToTensor(value, elem_type) as its "value" attribute.
+inline FunctionBodyHelper::NodeDef Const(const std::string& name, double value, TensorProto_DataType elem_type) {
+  return FunctionBodyHelper::NodeDef{
+      {name}, "Constant", {}, {{"value", ToTensor(value, elem_type)}}};
+}
+
 // Utility function to construct a FunctionProto from an opschema (for the signature information),
 // a sequence of NodeDefs (for the function body), and the relied opsets.
 bool BuildFunctionProto(FunctionProto& functionProto,
diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc
index f59a2ebf71..8c7794cba3 100644
--- a/orttraining/orttraining/core/graph/training_op_defs.cc
+++ b/orttraining/orttraining/core/graph/training_op_defs.cc
@@ -1596,7 +1596,47 @@ Example 4:
           "T",
           {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
           "Constrain input and output types to float tensors.")
-      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
+      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
+      .SetContextDependentFunctionBodyBuilder(
+          [](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
+            /* Default GeluGrad computation:
+              dX = dY * [0.5f * [erf(sqrt(1/2)*X) + 1.0] + alpha*X*exp(-0.5f * X * X)]
+            where alpha = 1/sqrt(2*pi); this expands to the ONNX graph built below.
+            */
+            // The element type of input 0 decides the type of every constant below;
+            // if no tensor type is available for it, bail out with false.
+            auto* tp = ctx.getInputType(0);
+            if ((tp == nullptr) || (!tp->has_tensor_type()))
+              return false;
+            const auto elem_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tp->tensor_type().elem_type());
+            // (2/sqrt(pi)) * (1/sqrt(2)) * 0.5 == 1/sqrt(2*pi)
+            const double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
+            // Every constant is handed to Const() as a double, so nothing is rounded
+            // to float before the conversion to elem_type (which may be double).
+            std::vector<ONNX_NAMESPACE::FunctionBodyHelper::NodeDef> body{
+                ONNX_NAMESPACE::Const("C_Half", 0.5, elem_type),
+                ONNX_NAMESPACE::Const("C_One", 1.0, elem_type),
+                ONNX_NAMESPACE::Const("C_SqrtHalf", M_SQRT1_2, elem_type),
+                ONNX_NAMESPACE::Const("C_MinusHalf", -0.5, elem_type),
+                ONNX_NAMESPACE::Const("C_alpha", kAlpha, elem_type),
+                {{"ErfArg"}, "Mul", {"X", "C_SqrtHalf"}},
+                {{"ErfTerm"}, "Erf", {"ErfArg"}},
+                {{"PartialSum"}, "Add", {"ErfTerm", "C_One"}},
+                {{"HalfPartialSum"}, "Mul", {"C_Half", "PartialSum"}},
+                {{"AlphaX"}, "Mul", {"X", "C_alpha"}},
+                {{"MinusHalfX"}, "Mul", {"C_MinusHalf", "X"}},
+                {{"ExpArg"}, "Mul", {"MinusHalfX", "X"}},
+                {{"ExpTerm"}, "Exp", {"ExpArg"}},
+                {{"Term3"}, "Mul", {"AlphaX", "ExpTerm"}},
+                {{"FullSum"}, "Add", {"HalfPartialSum", "Term3"}},
+                {{"dX"}, "Mul", {"dY", "FullSum"}}};
+
+            OperatorSetIdProto onnx_opset_13;
+            onnx_opset_13.set_domain("");
+            onnx_opset_13.set_version(13);
+
+            return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
+          });
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
       .SetDomain(kMSDomain)
diff --git a/orttraining/orttraining/test/gradient/function_ops_test.cc b/orttraining/orttraining/test/gradient/function_ops_test.cc
index 6eb06b2820..e74fed7136 100644
--- a/orttraining/orttraining/test/gradient/function_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/function_ops_test.cc
@@ -210,3 +210,19 @@
+// Sets up a GeluGrad case: inputs "dY" and "X" both have the given shape and hold
+// i / 100 at flat index i; "dX" is declared as the output.
+static void InitGeluGradTestCase(FunctionTestCase& testCase, std::vector<int64_t> shape) {
+  int64_t size = 1;
+  for (auto dim : shape)
+    size *= dim;
+
+  std::vector<float> value(size);
+  for (int64_t i = 0; i < size; i++)
+    value[i] = static_cast<float>(i) / 100.0f;
+
+  testCase.AddInput("dY", shape, value);
+  testCase.AddInput("X", shape, value);
+  testCase.AddOutput("dX");
+}
+
 TEST(SoftmaxGradExpansionTest, DefaultAxis) {
   FunctionTestCase testCase("SoftmaxGrad");
   InitSoftmaxGradTestCase(testCase, {3, 2});
@@ -272,5 +288,12 @@ TEST(SoftmaxGradExpansionTest, OpsetTest) {
   AssertEqual(results1, results2);
 }
 
+// Exercises GeluGrad on a 16x4 input via FunctionTestCase::RunTest.
+TEST(GeluGradExpansionTest, 2D) {
+  FunctionTestCase testCase("GeluGrad");
+  InitGeluGradTestCase(testCase, {16, 4});
+  testCase.RunTest();
+}
+
 }  // namespace test
 }  // namespace onnxruntime
\ No newline at end of file