Add ExpGrad registration and test. (#5438)

**Description**: Add missing gradient registration for the `Exp` op.

**Motivation and Context**
* Adding support for training a model that uses the `Exp` op.

Co-authored-by: Derek Murray <demurra@microsoft.com>
This commit is contained in:
Derek Murray 2020-10-12 13:56:08 -07:00 committed by GitHub
parent 2a018cc235
commit dbc626dcbe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 0 deletions

View file

@ -1456,5 +1456,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
return output;
}
IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) {
return std::vector<NodeDef>{
NodeDef("Mul",
{GO(0), O(0)},
{GI(0)})};
}
} // namespace training
} // namespace onnxruntime

View file

@ -65,6 +65,7 @@ DECLARE_GRADIENT_BUILDER(GetWhereGradient)
DECLARE_GRADIENT_BUILDER(GetSendGradient)
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
DECLARE_GRADIENT_BUILDER(GetExpGradient)
} // namespace training
} // namespace onnxruntime

View file

@ -96,6 +96,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
};
} // namespace training

View file

@ -416,6 +416,27 @@ TEST(GradientCheckerTest, LogGrad) {
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
TEST(GradientCheckerTest, ExpGrad) {
// Define input data with a narrower distribution than the default GradientChecker, to avoid
// precision issues.
TensorShape shape({2, 3, 4});
std::vector<std::vector<float>> x_datas(1);
const auto seed = GetTestRandomSeed();
std::default_random_engine generator{gsl::narrow_cast<decltype(generator)::result_type>(seed)};
std::uniform_real_distribution<float> distribution{-1.0, 1.0};
x_datas[0].resize(shape.Size());
std::generate(x_datas[0].begin(), x_datas[0].end(), [&] { return distribution(generator); });
float max_error;
float error_tolerance = 1e-3f;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Exp"};
gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, x_datas);
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
TEST(GradientCheckerTest, TanhGrad) {
UnaryOpGradientTest("Tanh");
}