From dbc626dcbe3e2e17529f0696759b3b7bc95918e2 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 12 Oct 2020 13:56:08 -0700 Subject: [PATCH] Add ExpGrad registration and test. (#5438) **Description**: Add missing gradient registration for the `Exp` op. **Motivation and Context** * Adding support for training a model that uses the `Exp` op. Co-authored-by: Derek Murray --- .../core/graph/gradient_builder.cc | 7 +++++++ .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../test/gradient/gradient_ops_test.cc | 21 +++++++++++++++++++ 4 files changed, 30 insertions(+) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 9fe1e898a4..0ed4590a19 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1456,5 +1456,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) { return output; } +IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) { + return std::vector{ + NodeDef("Mul", + {GO(0), O(0)}, + {GI(0)})}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 92f679753b..6116544102 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -65,6 +65,7 @@ DECLARE_GRADIENT_BUILDER(GetWhereGradient) DECLARE_GRADIENT_BUILDER(GetSendGradient) DECLARE_GRADIENT_BUILDER(GetRecvGradient) DECLARE_GRADIENT_BUILDER(GetExpandGradient) +DECLARE_GRADIENT_BUILDER(GetExpGradient) } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 9ac3f51419..c68926b73f 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -96,6 +96,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Send", GetSendGradient); REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient); REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient); + REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient); }; } // namespace training diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 7c59539c1b..a78e5d9fb5 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -416,6 +416,27 @@ TEST(GradientCheckerTest, LogGrad) { EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } +TEST(GradientCheckerTest, ExpGrad) { + // Define input data with a narrower distribution than the default GradientChecker, to avoid + // precision issues. + TensorShape shape({2, 3, 4}); + std::vector> x_datas(1); + const auto seed = GetTestRandomSeed(); + std::default_random_engine generator{gsl::narrow_cast(seed)}; + std::uniform_real_distribution distribution{-1.0, 1.0}; + x_datas[0].resize(shape.Size()); + std::generate(x_datas[0].begin(), x_datas[0].end(), [&] { return distribution(generator); }); + + float max_error; + float error_tolerance = 1e-3f; + GradientChecker gradient_checker; + OpDef op_def{"Exp"}; + + gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, x_datas); + + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); +} + TEST(GradientCheckerTest, TanhGrad) { UnaryOpGradientTest("Tanh"); }