From 64f6d856e4f06e4f4e54c257695a761e6e5c7f2a Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 15 Oct 2020 16:11:57 -0700 Subject: [PATCH] Add FlattenGrad and test. (#5461) Co-authored-by: Derek Murray --- .../orttraining/core/graph/gradient_builder.cc | 7 +++++++ .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../orttraining/test/gradient/gradient_ops_test.cc | 13 +++++++++++++ 4 files changed, 22 insertions(+) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 0ed4590a19..a11de95f3b 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1463,5 +1463,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) { {GI(0)})}; } +IMPLEMENT_GRADIENT_BUILDER(GetFlattenGradient) { + return std::vector{ + NodeDef("Shape", {I(0)}, {IA("input_shape")}), + NodeDef("Reshape", {GO(0), IA("input_shape")}, {GI(0)}) + }; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 6116544102..c6643335c3 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -66,6 +66,7 @@ DECLARE_GRADIENT_BUILDER(GetSendGradient) DECLARE_GRADIENT_BUILDER(GetRecvGradient) DECLARE_GRADIENT_BUILDER(GetExpandGradient) DECLARE_GRADIENT_BUILDER(GetExpGradient) +DECLARE_GRADIENT_BUILDER(GetFlattenGradient) } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index c68926b73f..3261d3cec5 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -97,6 +97,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient); REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient); REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient); + REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient); }; } // namespace training diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index a78e5d9fb5..291cbe8920 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -437,6 +437,19 @@ TEST(GradientCheckerTest, ExpGrad) { EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } +TEST(GradientCheckerTest, FlattenGrad) { + TensorShape shape({2, 3, 4}); + float max_error; + float error_tolerance = 1e-3f; + GradientChecker gradient_checker; + OpDef op_def{"Flatten", kOnnxDomain, 11}; + + for (int axis = -3; axis < 3; ++axis) { + gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, {MakeAttribute("axis", int64_t(axis))}); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } +} + TEST(GradientCheckerTest, TanhGrad) { UnaryOpGradientTest("Tanh"); }