mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Add ExpGrad registration and test. (#5438)
**Description**: Add missing gradient registration for the `Exp` op. **Motivation and Context** * Adding support for training a model that uses the `Exp` op. Co-authored-by: Derek Murray <demurra@microsoft.com>
This commit is contained in:
parent
2a018cc235
commit
dbc626dcbe
4 changed files with 30 additions and 0 deletions
|
|
@ -1456,5 +1456,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
|
|||
return output;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetExpGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Mul",
|
||||
{GO(0), O(0)},
|
||||
{GI(0)})};
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ DECLARE_GRADIENT_BUILDER(GetWhereGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetSendGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetExpGradient)
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -416,6 +416,27 @@ TEST(GradientCheckerTest, LogGrad) {
|
|||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ExpGrad) {
|
||||
// Define input data with a narrower distribution than the default GradientChecker, to avoid
|
||||
// precision issues.
|
||||
TensorShape shape({2, 3, 4});
|
||||
std::vector<std::vector<float>> x_datas(1);
|
||||
const auto seed = GetTestRandomSeed();
|
||||
std::default_random_engine generator{gsl::narrow_cast<decltype(generator)::result_type>(seed)};
|
||||
std::uniform_real_distribution<float> distribution{-1.0, 1.0};
|
||||
x_datas[0].resize(shape.Size());
|
||||
std::generate(x_datas[0].begin(), x_datas[0].end(), [&] { return distribution(generator); });
|
||||
|
||||
float max_error;
|
||||
float error_tolerance = 1e-3f;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Exp"};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {shape}, {shape}, &max_error, x_datas);
|
||||
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, TanhGrad) {
|
||||
UnaryOpGradientTest("Tanh");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue