2018-05-01 01:36:35 +00:00
|
|
|
#include <catch.hpp>
|
|
|
|
|
|
2018-06-26 20:23:16 +00:00
|
|
|
#include <torch/nn/modules/functional.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
#include <torch/nn/modules/linear.h>
|
|
|
|
|
#include <torch/nn/modules/sequential.h>
|
2018-06-26 17:13:14 +00:00
|
|
|
#include <torch/optim/optimizer.h>
|
|
|
|
|
#include <torch/optim/sgd.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
#include <torch/serialization.h>
|
2018-05-30 15:55:34 +00:00
|
|
|
#include <torch/tensor.h>
|
2018-06-28 03:00:53 +00:00
|
|
|
#include <torch/utils.h>
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-05-24 19:46:51 +00:00
|
|
|
#include <test/cpp/api/util.h>
|
|
|
|
|
|
|
|
|
|
#include <cereal/archives/portable_binary.hpp>
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-05-07 21:45:00 +00:00
|
|
|
using namespace torch::nn;
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-05-24 19:46:51 +00:00
|
|
|
namespace {
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
// Builds a small 2-8-1 feed-forward network with sigmoid activations.
// The serialization tests below train this model on XOR and then
// save/load it to verify that parameters survive a round trip.
Sequential xor_model() {
  Sequential model(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
  return model;
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2018-05-01 01:36:35 +00:00
|
|
|
TEST_CASE("serialization") {
|
2018-06-28 03:00:53 +00:00
|
|
|
torch::manual_seed(0);
|
2018-05-01 01:36:35 +00:00
|
|
|
SECTION("undefined") {
  // Saving an undefined tensor and loading it into a defined one must
  // leave the target undefined as well.
  auto undefined_tensor = torch::Tensor();
  REQUIRE(!undefined_tensor.defined());

  // Start from a defined tensor so we can observe it becoming undefined.
  auto target = torch::randn({5});

  std::stringstream stream;
  torch::save(stream, &undefined_tensor);
  torch::load(stream, &target);

  REQUIRE(!target.defined());
}
|
|
|
|
|
|
|
|
|
|
SECTION("cputypes") {
  // Round-trip a tensor of every serializable CPU dtype through
  // torch::save/torch::load and check the contents are preserved.
  for (int i = 0; i < static_cast<int>(torch::Dtype::NumOptions); i++) {
    const auto dtype = static_cast<torch::Dtype>(i);
    if (dtype == torch::Dtype::Half) {
      // XXX can't serialize half tensors at the moment since contiguous() is
      // not implemented for this type;
      continue;
    } else if (at::isComplexType(dtype)) {
      // Not supported yet
      continue;
    } else if (dtype == torch::Dtype::Undefined) {
      // We can't construct a tensor for this type. This is tested in
      // serialization/undefined anyway.
      continue;
    }

    auto saved = torch::ones({5, 5}, dtype);
    auto loaded = torch::empty({});

    std::stringstream stream;
    torch::save(stream, &saved);
    torch::load(stream, &loaded);

    REQUIRE(loaded.defined());
    REQUIRE(saved.sizes().vec() == loaded.sizes().vec());
    if (torch::isIntegralType(dtype)) {
      // Integral types must round-trip exactly.
      REQUIRE(saved.equal(loaded));
    } else {
      // Floating types are compared with a tolerance.
      REQUIRE(saved.allclose(loaded));
    }
  }
}
|
|
|
|
|
|
|
|
|
|
SECTION("binary") {
  // Serialize a tensor through cereal's (non-portable) binary archive
  // and verify shape and contents survive the round trip.
  auto original = torch::randn({5, 5});
  auto restored = torch::Tensor();

  std::stringstream stream;
  {
    // Archive destructor flushes the output before we read it back.
    cereal::BinaryOutputArchive archive(stream);
    archive(original);
  }
  {
    cereal::BinaryInputArchive archive(stream);
    archive(restored);
  }

  REQUIRE(restored.defined());
  REQUIRE(original.sizes().vec() == restored.sizes().vec());
  REQUIRE(original.allclose(restored));
}
|
|
|
|
|
SECTION("portable_binary") {
  // Same round trip as the "binary" section, but through cereal's
  // endian-safe portable binary archive.
  auto original = torch::randn({5, 5});
  auto restored = torch::Tensor();

  std::stringstream stream;
  {
    cereal::PortableBinaryOutputArchive archive(stream);
    archive(original);
  }
  {
    cereal::PortableBinaryInputArchive archive(stream);
    archive(restored);
  }

  REQUIRE(restored.defined());
  REQUIRE(original.sizes().vec() == restored.sizes().vec());
  REQUIRE(original.allclose(restored));
}
|
|
|
|
|
|
|
|
|
|
SECTION("resized") {
  // A tensor that was resized in place (storage larger than the view)
  // must still serialize its current 5x5 contents correctly.
  auto original = torch::randn({11, 5});
  original.resize_({5, 5});
  auto restored = torch::Tensor();

  std::stringstream stream;
  {
    cereal::BinaryOutputArchive archive(stream);
    archive(original);
  }
  {
    cereal::BinaryInputArchive archive(stream);
    archive(restored);
  }

  REQUIRE(restored.defined());
  REQUIRE(original.sizes().vec() == restored.sizes().vec());
  REQUIRE(original.allclose(restored));
}
|
|
|
|
|
SECTION("sliced") {
  // A row slice shares storage with (and is offset into) the original
  // tensor; serialization must capture only the sliced view.
  auto original = torch::randn({11, 5});
  original = original.slice(0, 1, 3);
  auto restored = torch::Tensor();

  std::stringstream stream;
  {
    cereal::BinaryOutputArchive archive(stream);
    archive(original);
  }
  {
    cereal::BinaryInputArchive archive(stream);
    archive(restored);
  }

  REQUIRE(restored.defined());
  REQUIRE(original.sizes().vec() == restored.sizes().vec());
  REQUIRE(original.allclose(restored));
}
|
|
|
|
|
|
|
|
|
|
SECTION("noncontig") {
  // Slicing along dim 1 yields a non-contiguous tensor; serialization
  // must handle the strided layout.
  auto original = torch::randn({11, 5});
  original = original.slice(1, 1, 4);
  auto restored = torch::Tensor();

  std::stringstream stream;
  {
    cereal::BinaryOutputArchive archive(stream);
    archive(original);
  }
  {
    cereal::BinaryInputArchive archive(stream);
    archive(restored);
  }

  REQUIRE(restored.defined());
  REQUIRE(original.sizes().vec() == restored.sizes().vec());
  REQUIRE(original.allclose(restored));
}
|
|
|
|
|
|
|
|
|
|
SECTION("xor") {
|
|
|
|
|
// We better be able to save and load a XOR model!
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
auto getLoss = [](Sequential model, uint32_t batch_size) {
|
2018-06-28 03:00:53 +00:00
|
|
|
auto inputs = torch::empty({batch_size, 2});
|
|
|
|
|
auto labels = torch::empty({batch_size});
|
|
|
|
|
for (size_t i = 0; i < batch_size; i++) {
|
|
|
|
|
inputs[i] = torch::randint(2, {2}, torch::kInt64);
|
|
|
|
|
labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
|
2018-05-01 01:36:35 +00:00
|
|
|
}
|
2018-06-28 03:00:53 +00:00
|
|
|
auto x = model->forward<torch::Tensor>(inputs);
|
|
|
|
|
return torch::binary_cross_entropy(x, labels);
|
2018-05-01 01:36:35 +00:00
|
|
|
};
|
|
|
|
|
|
2018-05-24 19:46:51 +00:00
|
|
|
auto model = xor_model();
|
|
|
|
|
auto model2 = xor_model();
|
|
|
|
|
auto model3 = xor_model();
|
2018-06-26 17:13:14 +00:00
|
|
|
auto optimizer = torch::optim::SGD(
|
|
|
|
|
model->parameters(),
|
|
|
|
|
torch::optim::SGDOptions(1e-1)
|
|
|
|
|
.momentum(0.9)
|
|
|
|
|
.nesterov(true)
|
|
|
|
|
.weight_decay(1e-6));
|
2018-05-01 01:36:35 +00:00
|
|
|
|
|
|
|
|
float running_loss = 1;
|
|
|
|
|
int epoch = 0;
|
|
|
|
|
while (running_loss > 0.1) {
|
2018-06-25 02:03:39 +00:00
|
|
|
torch::Tensor loss = getLoss(model, 4);
|
2018-06-26 17:13:14 +00:00
|
|
|
optimizer.zero_grad();
|
2018-05-25 00:31:41 +00:00
|
|
|
loss.backward();
|
2018-06-26 17:13:14 +00:00
|
|
|
optimizer.step();
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-08-16 04:09:33 +00:00
|
|
|
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
|
2018-05-01 01:36:35 +00:00
|
|
|
REQUIRE(epoch < 3000);
|
|
|
|
|
epoch++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::stringstream ss;
|
2018-06-25 02:03:39 +00:00
|
|
|
torch::save(ss, model);
|
|
|
|
|
torch::load(ss, model2);
|
2018-05-01 01:36:35 +00:00
|
|
|
|
|
|
|
|
auto loss = getLoss(model2, 100);
|
|
|
|
|
REQUIRE(loss.toCFloat() < 0.1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SECTION("optim") {
  // Verifies that optimizer state (SGD momentum buffers) serializes:
  // two steps with a saved/restored optimizer must match two steps with
  // a single live optimizer, and differ from two independent optimizers.
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same params
  std::stringstream ss;
  torch::save(ss, model1.get());
  torch::load(ss, model2.get());
  // Rewind the read position so model3 loads the same archive.
  ss.seekg(0, std::ios::beg);
  torch::load(ss, model3.get());

  auto param1 = model1->parameters();
  auto param2 = model2->parameters();
  auto param3 = model3->parameters();
  for (const auto& p : param1) {
    REQUIRE(param1[p.key].allclose(param2[p.key]));
    REQUIRE(param2[p.key].allclose(param3[p.key]));
  }

  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  // One optimization step: forward, sum to a scalar, backward, update.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3 while saving the optimizer
  step(optim3, model3);
  // NOTE(review): clear() resets only the stream's state flags; the
  // optimizer archive is appended after the module data already in `ss`
  // and read back from the current get position — confirm this stream
  // reuse is intentional rather than ss.str("").
  ss.clear();
  torch::save(ss, &optim3);
  torch::load(ss, &optim3_2);
  step(optim3_2, model3);

  param1 = model1->parameters();
  param2 = model2->parameters();
  param3 = model3->parameters();
  for (const auto& p : param1) {
    const auto& name = p.key;
    // Model 1 and 3 should be the same
    REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
    REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
  }
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_CASE("serialization_cuda", "[cuda]") {
|
2018-06-28 03:00:53 +00:00
|
|
|
torch::manual_seed(0);
|
2018-06-25 02:03:39 +00:00
|
|
|
// We better be able to save and load a XOR model!
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
auto getLoss = [](Sequential model, uint32_t batch_size) {
|
2018-06-28 03:00:53 +00:00
|
|
|
auto inputs = torch::empty({batch_size, 2});
|
|
|
|
|
auto labels = torch::empty({batch_size});
|
|
|
|
|
for (size_t i = 0; i < batch_size; i++) {
|
|
|
|
|
inputs[i] = torch::randint(2, {2}, torch::kInt64);
|
|
|
|
|
labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
|
2018-05-01 01:36:35 +00:00
|
|
|
}
|
2018-06-28 03:00:53 +00:00
|
|
|
auto x = model->forward<torch::Tensor>(inputs);
|
|
|
|
|
return torch::binary_cross_entropy(x, labels);
|
2018-06-25 02:03:39 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto model = xor_model();
|
|
|
|
|
auto model2 = xor_model();
|
|
|
|
|
auto model3 = xor_model();
|
2018-06-26 17:13:14 +00:00
|
|
|
auto optimizer = torch::optim::SGD(
|
|
|
|
|
model->parameters(),
|
|
|
|
|
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
|
|
|
|
|
1e-6));
|
2018-06-25 02:03:39 +00:00
|
|
|
|
|
|
|
|
float running_loss = 1;
|
|
|
|
|
int epoch = 0;
|
|
|
|
|
while (running_loss > 0.1) {
|
|
|
|
|
torch::Tensor loss = getLoss(model, 4);
|
2018-06-26 17:13:14 +00:00
|
|
|
optimizer.zero_grad();
|
2018-06-25 02:03:39 +00:00
|
|
|
loss.backward();
|
2018-06-26 17:13:14 +00:00
|
|
|
optimizer.step();
|
2018-06-25 02:03:39 +00:00
|
|
|
|
2018-08-16 04:09:33 +00:00
|
|
|
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
|
2018-06-25 02:03:39 +00:00
|
|
|
REQUIRE(epoch < 3000);
|
|
|
|
|
epoch++;
|
|
|
|
|
}
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-06-25 02:03:39 +00:00
|
|
|
std::stringstream ss;
|
|
|
|
|
torch::save(ss, model);
|
|
|
|
|
torch::load(ss, model2);
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-06-25 02:03:39 +00:00
|
|
|
auto loss = getLoss(model2, 100);
|
|
|
|
|
REQUIRE(loss.toCFloat() < 0.1);
|
2018-05-01 01:36:35 +00:00
|
|
|
|
2018-06-30 00:13:34 +00:00
|
|
|
model2->to(torch::kCUDA);
|
2018-06-25 02:03:39 +00:00
|
|
|
ss.clear();
|
|
|
|
|
torch::save(ss, model2);
|
|
|
|
|
torch::load(ss, model3);
|
|
|
|
|
|
|
|
|
|
loss = getLoss(model3, 100);
|
|
|
|
|
REQUIRE(loss.toCFloat() < 0.1);
|
2018-05-01 01:36:35 +00:00
|
|
|
}
|