mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add function body to GeluGrad schema (#7190)
* Add GeluGrad function definition * complete gelugrad function definition * add opset to function definition
This commit is contained in:
parent
dbcfc4bee6
commit
a9ff4c29e5
3 changed files with 62 additions and 1 deletions
|
|
@ -15,6 +15,12 @@ namespace ONNX_NAMESPACE {
|
|||
// For floating-value constants of different precision:
|
||||
TensorProto ToTensor(double value, TensorProto_DataType elem_type);
|
||||
|
||||
// Utility function to construct a constant of given type/precision.
|
||||
inline static FunctionBodyHelper::NodeDef Const(const std::string& name, double value, TensorProto_DataType elem_type) {
|
||||
return FunctionBodyHelper::NodeDef{
|
||||
{name}, "Constant", {}, {{"value", ToTensor(value, elem_type)}}};
|
||||
}
|
||||
|
||||
// Utility function to construct a FunctionProto from an opschema (for the signature information),
|
||||
// a sequence of NodeDefs (for the function body), and the relied opsets.
|
||||
bool BuildFunctionProto(FunctionProto& functionProto,
|
||||
|
|
|
|||
|
|
@ -1596,7 +1596,42 @@ Example 4:
|
|||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.SetContextDependentFunctionBodyBuilder(
|
||||
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
/* Default GeluGrad computation:
|
||||
dX = dY * [0.5f * [erf(sqrt(1/2)*X) + 1.0] + alpha*X*exp(-0.5f * X * X)]
|
||||
which expands to the following ONNX graph:
|
||||
*/
|
||||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
|
||||
double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
ONNX_NAMESPACE::Const("C_Half", 0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("C_One", 1.0f, elem_type),
|
||||
ONNX_NAMESPACE::Const("C_SqrtHalf", float(M_SQRT1_2), elem_type),
|
||||
ONNX_NAMESPACE::Const("C_MinusHalf", -0.5f, elem_type),
|
||||
ONNX_NAMESPACE::Const("C_alpha", kAlpha, elem_type),
|
||||
{{"ErfArg"}, "Mul", {"X", "C_SqrtHalf"}},
|
||||
{{"ErfTerm"}, "Erf", {"ErfArg"}},
|
||||
{{"PartialSum"}, "Add", {"ErfTerm", "C_One"}},
|
||||
{{"HalfPartialSum"}, "Mul", {"C_Half", "PartialSum"}},
|
||||
{{"AlphaX"}, "Mul", {"X", "C_alpha"}},
|
||||
{{"MinusHalfX"}, "Mul", {"C_MinusHalf", "X"}},
|
||||
{{"ExpArg"}, "Mul", {"MinusHalfX", "X"}},
|
||||
{{"ExpTerm"}, "Exp", {"ExpArg"}},
|
||||
{{"Term3"}, "Mul", {"AlphaX", "ExpTerm"}},
|
||||
{{"FullSum"}, "Add", {"HalfPartialSum", "Term3"}},
|
||||
{{"dX"}, "Mul", {"dY", "FullSum"}}};
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
|
|||
|
|
@ -207,6 +207,20 @@ static void InitSoftmaxGradTestCase(FunctionTestCase& testCase, std::vector<int6
|
|||
testCase.AddOutput("dX");
|
||||
}
|
||||
|
||||
static void InitGeluGradTestCase(FunctionTestCase& testCase, std::vector<int64_t> shape) {
|
||||
int64_t size = 1;
|
||||
for (auto dim : shape)
|
||||
size *= dim;
|
||||
|
||||
std::vector<float> value(size);
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
value[i] = float(i) / 100.0f;
|
||||
|
||||
testCase.AddInput("dY", shape, value);
|
||||
testCase.AddInput("X", shape, value);
|
||||
testCase.AddOutput("dX");
|
||||
}
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, DefaultAxis) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
InitSoftmaxGradTestCase(testCase, {3, 2});
|
||||
|
|
@ -272,5 +286,11 @@ TEST(SoftmaxGradExpansionTest, OpsetTest) {
|
|||
AssertEqual(results1, results2);
|
||||
}
|
||||
|
||||
TEST(GeluGradExpansionTest, 2D) {
|
||||
FunctionTestCase testCase("GeluGrad");
|
||||
InitGeluGradTestCase(testCase, {16, 4});
|
||||
testCase.RunTest();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue