Add function body to GeluGrad schema (#7190)

* Add GeluGrad function definition

* complete gelugrad function definition

* add opset to function definition
This commit is contained in:
G. Ramalingam 2021-04-06 12:40:59 -07:00 committed by GitHub
parent dbcfc4bee6
commit a9ff4c29e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 1 deletions

View file

@ -15,6 +15,12 @@ namespace ONNX_NAMESPACE {
// For floating-value constants of different precision:
TensorProto ToTensor(double value, TensorProto_DataType elem_type);
// Builds the NodeDef for a "Constant" node: it has no inputs, a single output
// called `name`, and a "value" attribute holding ToTensor(value, elem_type).
inline static FunctionBodyHelper::NodeDef Const(const std::string& name, double value, TensorProto_DataType elem_type) {
  FunctionBodyHelper::NodeDef constant_node{
      {name},                                      // outputs
      "Constant",                                  // op type
      {},                                          // inputs (none)
      {{"value", ToTensor(value, elem_type)}}};    // attributes
  return constant_node;
}
// Utility function to construct a FunctionProto from an opschema (for the signature information),
// a sequence of NodeDefs (for the function body), and the relied opsets.
bool BuildFunctionProto(FunctionProto& functionProto,

View file

@ -1596,7 +1596,42 @@ Example 4:
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder(
    [](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
      /* Default GeluGrad computation:
           dX = dY * [0.5 * [erf(sqrt(1/2)*X) + 1.0] + alpha*X*exp(-0.5 * X * X)]
         with alpha = 1/sqrt(2*pi) (computed below as M_2_SQRTPI * M_SQRT1_2 * 0.5),
         which expands to the ONNX graph assembled in `body`.
      */
      // The element type of input 0 decides the precision of every constant in the
      // expansion; without a known tensor type the builder reports failure.
      const auto* tp = ctx.getInputType(0);
      if ((tp == nullptr) || (!tp->has_tensor_type()))
        return false;
      const auto elem_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tp->tensor_type().elem_type());
      const double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
      // All constants are handed to Const() as doubles so that no precision is lost
      // before the conversion to `elem_type` (in particular sqrt(1/2) is no longer
      // rounded through float first, matching how C_alpha is passed).
      std::vector<FunctionBodyHelper::NodeDef> body{
          ONNX_NAMESPACE::Const("C_Half", 0.5, elem_type),
          ONNX_NAMESPACE::Const("C_One", 1.0, elem_type),
          ONNX_NAMESPACE::Const("C_SqrtHalf", M_SQRT1_2, elem_type),
          ONNX_NAMESPACE::Const("C_MinusHalf", -0.5, elem_type),
          ONNX_NAMESPACE::Const("C_alpha", kAlpha, elem_type),
          // HalfPartialSum = 0.5 * (erf(sqrt(1/2) * X) + 1)
          {{"ErfArg"}, "Mul", {"X", "C_SqrtHalf"}},
          {{"ErfTerm"}, "Erf", {"ErfArg"}},
          {{"PartialSum"}, "Add", {"ErfTerm", "C_One"}},
          {{"HalfPartialSum"}, "Mul", {"C_Half", "PartialSum"}},
          // Term3 = alpha * X * exp(-0.5 * X * X)
          {{"AlphaX"}, "Mul", {"X", "C_alpha"}},
          {{"MinusHalfX"}, "Mul", {"C_MinusHalf", "X"}},
          {{"ExpArg"}, "Mul", {"MinusHalfX", "X"}},
          {{"ExpTerm"}, "Exp", {"ExpArg"}},
          {{"Term3"}, "Mul", {"AlphaX", "ExpTerm"}},
          // dX = dY * (HalfPartialSum + Term3)
          {{"FullSum"}, "Add", {"HalfPartialSum", "Term3"}},
          {{"dX"}, "Mul", {"dY", "FullSum"}}};
      // The body nodes are declared against the default ONNX domain, opset 13.
      OperatorSetIdProto onnx_opset_13;
      onnx_opset_13.set_domain("");
      onnx_opset_13.set_version(13);
      return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
    });
ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
.SetDomain(kMSDomain)

View file

@ -207,6 +207,20 @@ static void InitSoftmaxGradTestCase(FunctionTestCase& testCase, std::vector<int6
testCase.AddOutput("dX");
}
// Populates `testCase` for GeluGrad: inputs "dY" and "X" both receive the same
// data of the requested `shape` (element k holds k / 100), and "dX" is registered
// as the output.
static void InitGeluGradTestCase(FunctionTestCase& testCase, std::vector<int64_t> shape) {
  // Element count is the product of all dimensions (1 for an empty shape).
  int64_t num_elements = 1;
  for (auto extent : shape) {
    num_elements *= extent;
  }

  std::vector<float> data(num_elements);
  int64_t position = 0;
  for (float& element : data) {
    element = float(position) / 100.0f;
    ++position;
  }

  testCase.AddInput("dY", shape, data);
  testCase.AddInput("X", shape, data);
  testCase.AddOutput("dX");
}
TEST(SoftmaxGradExpansionTest, DefaultAxis) {
FunctionTestCase testCase("SoftmaxGrad");
InitSoftmaxGradTestCase(testCase, {3, 2});
@ -272,5 +286,11 @@ TEST(SoftmaxGradExpansionTest, OpsetTest) {
AssertEqual(results1, results2);
}
// Runs the GeluGrad function test case on a 2-D input of shape 16x4.
TEST(GeluGradExpansionTest, 2D) {
  FunctionTestCase gelu_grad_case("GeluGrad");
  InitGeluGradTestCase(gelu_grad_case, {16, 4});  // shape: 16 rows x 4 columns
  gelu_grad_case.RunTest();
}
} // namespace test
} // namespace onnxruntime