mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Add function-body to opschema of FastGeluGrad (#9028)
* Add function body to FastGeluGrad * Add test case
This commit is contained in:
parent
4322f7e647
commit
7d28b596f4
2 changed files with 61 additions and 17 deletions
|
|
@ -2613,7 +2613,46 @@ Return true if all elements are true and false otherwise.
|
|||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.SetContextDependentFunctionBodyBuilder(
|
||||
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
static constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2;
|
||||
static constexpr double kGamma = 0.044715f;
|
||||
static constexpr double kBeta = kGamma * kAlpha * 3.0f;
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
ONNX_NAMESPACE::Const("half", 0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("one", 1.0f, elem_type),
|
||||
ONNX_NAMESPACE::Const("alpha", kAlpha, elem_type),
|
||||
ONNX_NAMESPACE::Const("gamma", kGamma, elem_type),
|
||||
ONNX_NAMESPACE::Const("beta", kBeta, elem_type),
|
||||
{{"x_square"}, "Mul", {"X", "X"}},
|
||||
{{"x_cube"}, "Mul", {"X", "x_square"}},
|
||||
{{"gamma_x_cube"}, "Mul", {"gamma", "x_cube"}},
|
||||
{{"sum1"}, "Add", {"X", "gamma_x_cube"}},
|
||||
{{"tanh_arg"}, "Mul", {"alpha", "sum1"}},
|
||||
{{"tanh_val"}, "Tanh", {"tanh_arg"}},
|
||||
{{"tanh_square"}, "Mul", {"tanh_val", "tanh_val"}},
|
||||
{{"sech_square"}, "Sub", {"one", "tanh_square"}},
|
||||
{{"alpha_x"}, "Mul", {"alpha", "X"}},
|
||||
{{"beta_x_cube"}, "Mul", {"beta", "x_cube"}},
|
||||
{{"sum"}, "Add", {"alpha_x", "beta_x_cube"}},
|
||||
{{"term2"}, "Mul", {"sech_square", "sum"}},
|
||||
{{"sum2"}, "Add", {"tanh_val", "term2"}},
|
||||
{{"sum3"}, "Add", {"sum2", "one"}},
|
||||
{{"prod"}, "Mul", {"half", "sum3"}},
|
||||
{{"dX"}, "Mul", {"dY", "prod"}},
|
||||
};
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(BiasGeluGrad_dX)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
|
|||
|
|
@ -169,29 +169,34 @@ TEST_F(FunExpansionTest, DropoutGrad_WithRatio2) {
|
|||
CheckDropoutGradWithRatio<MLFloat16>(true);
|
||||
}
|
||||
|
||||
TEST_F(FunExpansionTest, GeluGrad_2D) {
|
||||
FunctionTestCase testCase("GeluGrad");
|
||||
template <typename T, bool RunTest = true>
|
||||
void TestUnaryOpGrad(const char* opname) {
|
||||
|
||||
FunctionTestCase testCase(opname);
|
||||
std::vector<int64_t> shape{16, 4};
|
||||
testCase.AddInput<float>("dY", shape);
|
||||
testCase.AddInput<float>("X", shape);
|
||||
testCase.AddInput<T, RunTest>("dY", shape);
|
||||
testCase.AddInput<T, RunTest>("X", shape);
|
||||
testCase.AddOutput("dX");
|
||||
testCase.RunTest();
|
||||
if (RunTest)
|
||||
testCase.RunTest();
|
||||
else
|
||||
// Test only expanded model creation and model checking.
|
||||
testCase.CreateModel(true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckGeluGrad() {
|
||||
// Tests only expanded model creation and checking.
|
||||
FunctionTestCase testCase("GeluGrad");
|
||||
std::vector<int64_t> shape{16, 4};
|
||||
testCase.AddInput<T, false>("dY", shape);
|
||||
testCase.AddInput<T, false>("X", shape);
|
||||
testCase.AddOutput("dX");
|
||||
testCase.CreateModel(true);
|
||||
TEST_F(FunExpansionTest, GeluGrad_float) {
|
||||
TestUnaryOpGrad<float, true>("GeluGrad");
|
||||
}
|
||||
|
||||
TEST_F(FunExpansionTest, GeluGrad_HalfPrecision) {
|
||||
CheckGeluGrad<BFloat16>();
|
||||
CheckGeluGrad<MLFloat16>();
|
||||
TestUnaryOpGrad<BFloat16, false>("GeluGrad");
|
||||
TestUnaryOpGrad<MLFloat16, false>("GeluGrad");
|
||||
}
|
||||
|
||||
TEST_F(FunExpansionTest, FastGeluGrad) {
|
||||
TestUnaryOpGrad<float, true>("FastGeluGrad");
|
||||
TestUnaryOpGrad<BFloat16, false>("FastGeluGrad");
|
||||
TestUnaryOpGrad<MLFloat16, false>("FastGeluGrad");
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue