Add a gradient builder for the Exp operator.

Declares, implements and registers GetExpGradient for "Exp", and adds a
GradientChecker test that feeds inputs drawn from [-1, 1).

NOTE(review): this patch was reflowed from a whitespace-collapsed copy. The
indentation, the spacing of context lines, and the template argument lists
(<NodeDef>, <std::vector<float>>, <float>, <float, float, float>, and the
narrow_cast target type) were reconstructed -- confirm against the upstream
commit before applying. The `index` blob hashes are carried over unchanged.

diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 9fe1e898a4..0ed4590a19 100644
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1456,5 +1456,15 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
   return output;
 }
 
+// Gradient for Exp, emitted as a single Mul node: GI(0) = GO(0) * O(0).
+// Since d/dx exp(x) = exp(x), the forward output is reused rather than recomputed
+// (assumes GO/O/GI name the output gradient, forward output and input gradient -- confirm).
+IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) {
+  return std::vector<NodeDef>{
+      NodeDef("Mul",
+              {GO(0), O(0)},
+              {GI(0)})};
+}
+
 }  // namespace training
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 92f679753b..6116544102 100644
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -65,6 +65,7 @@ DECLARE_GRADIENT_BUILDER(GetWhereGradient)
 DECLARE_GRADIENT_BUILDER(GetSendGradient)
 DECLARE_GRADIENT_BUILDER(GetRecvGradient)
 DECLARE_GRADIENT_BUILDER(GetExpandGradient)
+DECLARE_GRADIENT_BUILDER(GetExpGradient)
 
 }  // namespace training
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index 9ac3f51419..c68926b73f 100644
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -96,6 +96,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
   REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
   REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
+  REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
 };
 
 }  // namespace training
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index 7c59539c1b..a78e5d9fb5 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -416,6 +416,27 @@ TEST(GradientCheckerTest, LogGrad) {
   EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
 }
 
+TEST(GradientCheckerTest, ExpGrad) {
+  // Define input data with a narrower distribution than the default GradientChecker, to avoid
+  // precision issues.
+  TensorShape shape({2, 3, 4});
+  std::vector<std::vector<float>> x_datas(1);
+  const auto seed = GetTestRandomSeed();
+  std::default_random_engine generator{gsl::narrow_cast<std::default_random_engine::result_type>(seed)};
+  std::uniform_real_distribution<float> distribution{-1.0, 1.0};
+  x_datas[0].resize(shape.Size());
+  std::generate(x_datas[0].begin(), x_datas[0].end(), [&] { return distribution(generator); });
+
+  float max_error;
+  float error_tolerance = 1e-3f;
+  GradientChecker<float, float, float> gradient_checker;
+  OpDef op_def{"Exp"};
+
+  gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, x_datas);
+
+  EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
+}
+
 TEST(GradientCheckerTest, TanhGrad) {
   UnaryOpGradientTest("Tanh");
 }