Add FlattenGrad and test. (#5461)

Co-authored-by: Derek Murray <demurra@microsoft.com>
This commit is contained in:
Derek Murray 2020-10-15 16:11:57 -07:00 committed by GitHub
parent 88f6523baf
commit 64f6d856e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 0 deletions

View file

@ -1463,5 +1463,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) {
{GI(0)})};
}
IMPLEMENT_GRADIENT_BUILDER(GetFlattenGradient) {
return std::vector<NodeDef>{
NodeDef("Shape", {I(0)}, {IA("input_shape")}),
NodeDef("Reshape", {GO(0), IA("input_shape")}, {GI(0)})
};
}
} // namespace training
} // namespace onnxruntime

View file

@ -66,6 +66,7 @@ DECLARE_GRADIENT_BUILDER(GetSendGradient)
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
DECLARE_GRADIENT_BUILDER(GetExpGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenGradient)
} // namespace training
} // namespace onnxruntime

View file

@ -97,6 +97,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient);
};
} // namespace training

View file

@ -437,6 +437,19 @@ TEST(GradientCheckerTest, ExpGrad) {
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
TEST(GradientCheckerTest, FlattenGrad) {
TensorShape shape({2, 3, 4});
float max_error;
float error_tolerance = 1e-3f;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Flatten", kOnnxDomain, 11};
for (int axis = -3; axis < 3; ++axis) {
gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, {MakeAttribute("axis", int64_t(axis))});
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
}
TEST(GradientCheckerTest, TanhGrad) {
UnaryOpGradientTest("Tanh");
}