Add function-body to opschema of FastGeluGrad (#9028)

* Add function body to FastGeluGrad

* Add test case
This commit is contained in:
G. Ramalingam 2021-09-14 12:27:55 -07:00 committed by GitHub
parent 4322f7e647
commit 7d28b596f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 17 deletions

View file

@ -2613,7 +2613,46 @@ Return true if all elements are true and false otherwise.
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
static constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2;
static constexpr double kGamma = 0.044715f;
static constexpr double kBeta = kGamma * kAlpha * 3.0f;
std::vector<FunctionBodyHelper::NodeDef> body{
ONNX_NAMESPACE::Const("half", 0.5f, elem_type),
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
ONNX_NAMESPACE::Const("alpha", kAlpha, elem_type),
ONNX_NAMESPACE::Const("gamma", kGamma, elem_type),
ONNX_NAMESPACE::Const("beta", kBeta, elem_type),
{{"x_square"}, "Mul", {"X", "X"}},
{{"x_cube"}, "Mul", {"X", "x_square"}},
{{"gamma_x_cube"}, "Mul", {"gamma", "x_cube"}},
{{"sum1"}, "Add", {"X", "gamma_x_cube"}},
{{"tanh_arg"}, "Mul", {"alpha", "sum1"}},
{{"tanh_val"}, "Tanh", {"tanh_arg"}},
{{"tanh_square"}, "Mul", {"tanh_val", "tanh_val"}},
{{"sech_square"}, "Sub", {"one", "tanh_square"}},
{{"alpha_x"}, "Mul", {"alpha", "X"}},
{{"beta_x_cube"}, "Mul", {"beta", "x_cube"}},
{{"sum"}, "Add", {"alpha_x", "beta_x_cube"}},
{{"term2"}, "Mul", {"sech_square", "sum"}},
{{"sum2"}, "Add", {"tanh_val", "term2"}},
{{"sum3"}, "Add", {"sum2", "one"}},
{{"prod"}, "Mul", {"half", "sum3"}},
{{"dX"}, "Mul", {"dY", "prod"}},
};
OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
});
ONNX_CONTRIB_OPERATOR_SCHEMA(BiasGeluGrad_dX)
.SetDomain(kMSDomain)

View file

@ -169,29 +169,34 @@ TEST_F(FunExpansionTest, DropoutGrad_WithRatio2) {
CheckDropoutGradWithRatio<MLFloat16>(true);
}
TEST_F(FunExpansionTest, GeluGrad_2D) {
FunctionTestCase testCase("GeluGrad");
template <typename T, bool RunTest = true>
void TestUnaryOpGrad(const char* opname) {
FunctionTestCase testCase(opname);
std::vector<int64_t> shape{16, 4};
testCase.AddInput<float>("dY", shape);
testCase.AddInput<float>("X", shape);
testCase.AddInput<T, RunTest>("dY", shape);
testCase.AddInput<T, RunTest>("X", shape);
testCase.AddOutput("dX");
testCase.RunTest();
if (RunTest)
testCase.RunTest();
else
// Test only expanded model creation and model checking.
testCase.CreateModel(true);
}
template <typename T>
void CheckGeluGrad() {
// Tests only expanded model creation and checking.
FunctionTestCase testCase("GeluGrad");
std::vector<int64_t> shape{16, 4};
testCase.AddInput<T, false>("dY", shape);
testCase.AddInput<T, false>("X", shape);
testCase.AddOutput("dX");
testCase.CreateModel(true);
TEST_F(FunExpansionTest, GeluGrad_float) {
TestUnaryOpGrad<float, true>("GeluGrad");
}
TEST_F(FunExpansionTest, GeluGrad_HalfPrecision) {
CheckGeluGrad<BFloat16>();
CheckGeluGrad<MLFloat16>();
TestUnaryOpGrad<BFloat16, false>("GeluGrad");
TestUnaryOpGrad<MLFloat16, false>("GeluGrad");
}
TEST_F(FunExpansionTest, FastGeluGrad) {
TestUnaryOpGrad<float, true>("FastGeluGrad");
TestUnaryOpGrad<BFloat16, false>("FastGeluGrad");
TestUnaryOpGrad<MLFloat16, false>("FastGeluGrad");
}
} // namespace test