2018-05-01 01:36:35 +00:00
|
|
|
#include <catch.hpp>
|
|
|
|
|
|
2018-05-04 15:04:57 +00:00
|
|
|
#include <torch/torch.h>
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-05-04 15:04:57 +00:00
|
|
|
using namespace torch;
|
2018-05-07 21:45:00 +00:00
|
|
|
using namespace torch::nn;
|
2018-05-01 01:36:35 +00:00
|
|
|
|
|
|
|
|
// Trains `model` on randomly generated XOR batches with the given optimizer
// and returns true once the exponentially smoothed loss drops below 0.1.
// Returns false if that has not happened within 3000 epochs — i.e. the
// optimizer failed to minimize. Used as a convergence smoke test.
bool test_optimizer_xor(Optimizer optim, std::shared_ptr<ContainerList> model) {
  // Exponential moving average of the loss; starts pessimistic at 1.
  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    int64_t bs = 4;
    auto inp = at::CPU(at::kFloat).tensor({bs, 2});
    auto lab = at::CPU(at::kFloat).tensor({bs});
    // Build a fresh random batch: inputs (a, b) in {0,1}, label a XOR b.
    // int64_t index: avoids the signed/unsigned comparison with `bs`
    // that the previous `size_t` loop variable caused.
    for (int64_t i = 0; i < bs; i++) {
      const int64_t a = std::rand() % 2;
      const int64_t b = std::rand() % 2;
      const int64_t c = static_cast<uint64_t>(a) ^ static_cast<uint64_t>(b);
      inp[i][0] = a;
      inp[i][1] = b;
      lab[i] = c;
    }
    // forward
    auto input = Var(inp);
    auto target = Var(lab, false); // no gradient needed for the target

    // The closure re-evaluates the model and loss; optimizers such as
    // LBFGS may invoke it multiple times per step.
    std::function<at::Scalar()> closure = [&]() -> at::Scalar {
      optim->zero_grad();
      auto x = input;
      // Every layer is followed by an in-place sigmoid activation.
      for (auto& layer : *model)
        x = layer->forward({x})[0].sigmoid_();
      Variable loss = at::binary_cross_entropy(x, target);
      backward(loss);
      return at::Scalar(loss.data());
    };

    at::Scalar loss = optim->step(closure);
    // Smooth the loss so one lucky batch cannot end training early.
    running_loss = running_loss * 0.99 + loss.toFloat() * 0.01;
    if (epoch > 3000) {
      return false; // failed to converge
    }
    epoch++;
  }
  return true;
}
|
|
|
|
|
|
|
|
|
|
TEST_CASE("optim") {
  // Seed both the C RNG (batch generation) and torch (weight init) so
  // every section trains on an identical sequence of data.
  std::srand(0);
  setSeed(0);

  // A small two-layer network; enough capacity to learn XOR.
  auto net = std::make_shared<ContainerList>();
  net->append(Linear(2, 8).build());
  net->append(Linear(8, 1).build());

  SECTION("lbfgs") {
    auto optimizer = LBFGS(net, 5e-2).max_iter(5).make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }

  SECTION("sgd") {
    auto optimizer =
        SGD(net, 1e-1).momentum(0.9).nesterov().weight_decay(1e-6).make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }

  SECTION("adagrad") {
    auto optimizer =
        Adagrad(net, 1.0).weight_decay(1e-6).lr_decay(1e-3).make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }

  SECTION("rmsprop_simple") {
    auto optimizer = RMSprop(net, 1e-1).centered().make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }

  SECTION("rmsprop") {
    auto optimizer =
        RMSprop(net, 1e-1).momentum(0.9).weight_decay(1e-6).make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }

  /*
  // This test appears to be flaky, see
  // https://github.com/pytorch/pytorch/issues/7288
  SECTION("adam") {
    auto optimizer = Adam(net, 1.0).weight_decay(1e-6).make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }
  */

  SECTION("amsgrad") {
    auto optimizer = Adam(net, 0.1).weight_decay(1e-6).amsgrad().make();
    REQUIRE(test_optimizer_xor(optimizer, net));
  }
}
|