diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 1f00e08432..9efe5cbb61 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2613,7 +2613,46 @@ Return true if all elements are true and false otherwise. "T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput) + .SetContextDependentFunctionBodyBuilder( + [](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) { + auto* tp = ctx.getInputType(0); + if ((tp == nullptr) || (!tp->has_tensor_type())) + return false; + auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type(); + static constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2; + static constexpr double kGamma = 0.044715f; + static constexpr double kBeta = kGamma * kAlpha * 3.0f; + std::vector body{ + ONNX_NAMESPACE::Const("half", 0.5f, elem_type), + ONNX_NAMESPACE::Const("one", 1.0f, elem_type), + ONNX_NAMESPACE::Const("alpha", kAlpha, elem_type), + ONNX_NAMESPACE::Const("gamma", kGamma, elem_type), + ONNX_NAMESPACE::Const("beta", kBeta, elem_type), + {{"x_square"}, "Mul", {"X", "X"}}, + {{"x_cube"}, "Mul", {"X", "x_square"}}, + {{"gamma_x_cube"}, "Mul", {"gamma", "x_cube"}}, + {{"sum1"}, "Add", {"X", "gamma_x_cube"}}, + {{"tanh_arg"}, "Mul", {"alpha", "sum1"}}, + {{"tanh_val"}, "Tanh", {"tanh_arg"}}, + {{"tanh_square"}, "Mul", {"tanh_val", "tanh_val"}}, + {{"sech_square"}, "Sub", {"one", "tanh_square"}}, + {{"alpha_x"}, "Mul", {"alpha", "X"}}, + {{"beta_x_cube"}, "Mul", {"beta", "x_cube"}}, + {{"sum"}, "Add", {"alpha_x", "beta_x_cube"}}, + {{"term2"}, "Mul", {"sech_square", "sum"}}, + {{"sum2"}, "Add", {"tanh_val", "term2"}}, + {{"sum3"}, "Add", {"sum2", "one"}}, + {{"prod"}, "Mul", {"half", "sum3"}}, + {{"dX"}, "Mul", {"dY", "prod"}}, + }; + + OperatorSetIdProto onnx_opset_13; + onnx_opset_13.set_domain(""); + onnx_opset_13.set_version(13); + + return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + }); ONNX_CONTRIB_OPERATOR_SCHEMA(BiasGeluGrad_dX) .SetDomain(kMSDomain) diff --git a/orttraining/orttraining/test/gradient/function_ops_test.cc b/orttraining/orttraining/test/gradient/function_ops_test.cc index 665114df71..5df08a59b5 100644 --- a/orttraining/orttraining/test/gradient/function_ops_test.cc +++ b/orttraining/orttraining/test/gradient/function_ops_test.cc @@ -169,29 +169,34 @@ TEST_F(FunExpansionTest, DropoutGrad_WithRatio2) { CheckDropoutGradWithRatio(true); } -TEST_F(FunExpansionTest, GeluGrad_2D) { - FunctionTestCase testCase("GeluGrad"); +template +void TestUnaryOpGrad(const char* opname) { + + FunctionTestCase testCase(opname); std::vector shape{16, 4}; - testCase.AddInput("dY", shape); - testCase.AddInput("X", shape); + testCase.AddInput("dY", shape); + testCase.AddInput("X", shape); testCase.AddOutput("dX"); - testCase.RunTest(); + if (RunTest) + testCase.RunTest(); + else + // Test only expanded model creation and model checking. + testCase.CreateModel(true); } -template -void CheckGeluGrad() { - // Tests only expanded model creation and checking. - FunctionTestCase testCase("GeluGrad"); - std::vector shape{16, 4}; - testCase.AddInput("dY", shape); - testCase.AddInput("X", shape); - testCase.AddOutput("dX"); - testCase.CreateModel(true); +TEST_F(FunExpansionTest, GeluGrad_float) { + TestUnaryOpGrad("GeluGrad"); } TEST_F(FunExpansionTest, GeluGrad_HalfPrecision) { - CheckGeluGrad(); - CheckGeluGrad(); + TestUnaryOpGrad("GeluGrad"); + TestUnaryOpGrad("GeluGrad"); +} + +TEST_F(FunExpansionTest, FastGeluGrad) { + TestUnaryOpGrad("FastGeluGrad"); + TestUnaryOpGrad("FastGeluGrad"); + TestUnaryOpGrad("FastGeluGrad"); } } // namespace test