mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Add FlattenGrad and test. (#5461)
Co-authored-by: Derek Murray <demurra@microsoft.com>
This commit is contained in:
parent
88f6523baf
commit
64f6d856e4
4 changed files with 22 additions and 0 deletions
|
|
@ -1463,5 +1463,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) {
|
|||
{GI(0)})};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetFlattenGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Shape", {I(0)}, {IA("input_shape")}),
|
||||
NodeDef("Reshape", {GO(0), IA("input_shape")}, {GI(0)})
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ DECLARE_GRADIENT_BUILDER(GetSendGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetExpGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetFlattenGradient)
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -437,6 +437,19 @@ TEST(GradientCheckerTest, ExpGrad) {
|
|||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, FlattenGrad) {
|
||||
TensorShape shape({2, 3, 4});
|
||||
float max_error;
|
||||
float error_tolerance = 1e-3f;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Flatten", kOnnxDomain, 11};
|
||||
|
||||
for (int axis = -3; axis < 3; ++axis) {
|
||||
gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, {MakeAttribute("axis", int64_t(axis))});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, TanhGrad) {
|
||||
UnaryOpGradientTest("Tanh");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue