2018-05-24 19:46:51 +00:00
|
|
|
#include <catch.hpp>
|
|
|
|
|
|
2018-06-26 20:23:16 +00:00
|
|
|
#include <torch/nn/modules.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
#include <torch/nn/modules/linear.h>
|
|
|
|
|
#include <torch/nn/modules/sequential.h>
|
2018-05-30 15:55:34 +00:00
|
|
|
#include <torch/tensor.h>
|
2018-06-28 03:00:53 +00:00
|
|
|
#include <torch/utils.h>
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-06-25 02:03:39 +00:00
|
|
|
#include <memory>
|
2018-05-24 19:46:51 +00:00
|
|
|
#include <vector>
|
|
|
|
|
|
2018-07-17 04:43:40 +00:00
|
|
|
#include <test/cpp/api/util.h>
|
|
|
|
|
|
2018-05-24 19:46:51 +00:00
|
|
|
using namespace torch::nn;
|
2018-07-17 04:43:40 +00:00
|
|
|
using namespace torch::test;
|
2018-05-24 19:46:51 +00:00
|
|
|
|
|
|
|
|
using Catch::StartsWith;
|
|
|
|
|
|
|
|
|
|
TEST_CASE("sequential") {
|
2018-06-19 02:45:53 +00:00
|
|
|
SECTION("construction from shared pointer") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Minimal test module: stores the integer it is constructed with and
// returns that same integer from forward().
struct M : torch::nn::Module {
|
2018-06-19 02:45:53 +00:00
|
|
|
// `explicit` prevents implicit int -> M conversions.
explicit M(int value_) : value(value_) {}
|
|
|
|
|
// The value handed back by forward().
int value;
|
|
|
|
|
// Takes no input and returns the stored value.
int forward() {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
};
|
2018-05-24 19:46:51 +00:00
|
|
|
Sequential sequential(
|
2018-06-19 02:45:53 +00:00
|
|
|
std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->size() == 3);
|
2018-06-19 02:45:53 +00:00
|
|
|
}
|
|
|
|
|
SECTION("construction from concrete type") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Minimal test module: stores the integer it is constructed with and
// returns that same integer from forward().
struct M : torch::nn::Module {
|
2018-06-19 02:45:53 +00:00
|
|
|
// `explicit` prevents implicit int -> M conversions.
explicit M(int value_) : value(value_) {}
|
|
|
|
|
// The value handed back by forward().
int value;
|
|
|
|
|
// Takes no input and returns the stored value.
int forward() {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Sequential sequential(M(1), M(2), M(3));
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->size() == 3);
|
2018-06-19 02:45:53 +00:00
|
|
|
}
|
|
|
|
|
SECTION("construction from module holders") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Implementation class for the holder-based test: stores an integer and
// returns it from forward(). It is wrapped by a ModuleHolder below.
struct MImpl : torch::nn::Module {
|
2018-06-19 02:45:53 +00:00
|
|
|
// `explicit` prevents implicit int -> MImpl conversions.
explicit MImpl(int value_) : value(value_) {}
|
|
|
|
|
// Takes no input and returns the stored value.
int forward() {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
// The value handed back by forward().
int value;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Holder type for MImpl, derived from torch::nn::ModuleHolder<MImpl>.
// It adds no members of its own; it only re-uses the base class's API.
struct M : torch::nn::ModuleHolder<MImpl> {
|
|
|
|
|
// Inherit the base class's constructors, so `M(1)` is well-formed.
using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
|
|
|
|
|
// Make the base class's get() available as a public member of M.
using torch::nn::ModuleHolder<MImpl>::get;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Sequential sequential(M(1), M(2), M(3));
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->size() == 3);
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
|
|
|
|
SECTION("push_back") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Minimal test module: stores the integer it is constructed with and
// returns that same integer from forward().
struct M : torch::nn::Module {
|
2018-06-19 02:45:53 +00:00
|
|
|
// `explicit` prevents implicit int -> M conversions.
explicit M(int value_) : value(value_) {}
|
|
|
|
|
// Takes no input and returns the stored value.
int forward() {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
// The value handed back by forward().
int value;
|
|
|
|
|
};
|
2018-05-24 19:46:51 +00:00
|
|
|
Sequential sequential;
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->size() == 0);
|
|
|
|
|
REQUIRE(sequential->is_empty());
|
|
|
|
|
sequential->push_back(Linear(3, 4));
|
|
|
|
|
REQUIRE(sequential->size() == 1);
|
|
|
|
|
sequential->push_back(std::make_shared<M>(1));
|
|
|
|
|
REQUIRE(sequential->size() == 2);
|
|
|
|
|
sequential->push_back(M(2));
|
|
|
|
|
REQUIRE(sequential->size() == 3);
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
|
|
|
|
SECTION("access") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Minimal test module: stores the integer it is constructed with and
// returns that same integer from forward().
struct M : torch::nn::Module {
|
2018-06-19 02:45:53 +00:00
|
|
|
// `explicit` prevents implicit int -> M conversions.
explicit M(int value_) : value(value_) {}
|
|
|
|
|
// Takes no input and returns the stored value.
int forward() {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
// The value handed back by forward().
int value;
|
|
|
|
|
};
|
|
|
|
|
std::vector<std::shared_ptr<M>> modules = {
|
|
|
|
|
std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};
|
2018-05-24 19:46:51 +00:00
|
|
|
|
|
|
|
|
Sequential sequential;
|
|
|
|
|
for (auto& module : modules) {
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->push_back(module);
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->size() == 3);
|
2018-05-24 19:46:51 +00:00
|
|
|
|
|
|
|
|
SECTION("at()") {
|
|
|
|
|
SECTION("returns the correct module for a given index") {
|
|
|
|
|
for (size_t i = 0; i < modules.size(); ++i) {
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(&sequential->at<M>(i) == modules[i].get());
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
SECTION("throws for a bad index") {
|
|
|
|
|
REQUIRE_THROWS_WITH(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->at<M>(modules.size() + 1),
|
2018-05-24 19:46:51 +00:00
|
|
|
StartsWith("Index out of range"));
|
|
|
|
|
REQUIRE_THROWS_WITH(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->at<M>(modules.size() + 1000000),
|
2018-05-24 19:46:51 +00:00
|
|
|
StartsWith("Index out of range"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SECTION("ptr()") {
|
|
|
|
|
SECTION("returns the correct module for a given index") {
|
|
|
|
|
for (size_t i = 0; i < modules.size(); ++i) {
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->ptr(i).get() == modules[i].get());
|
2018-05-24 19:46:51 +00:00
|
|
|
REQUIRE(sequential[i].get() == modules[i].get());
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->ptr<M>(i).get() == modules[i].get());
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
SECTION("throws for a bad index") {
|
|
|
|
|
REQUIRE_THROWS_WITH(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->ptr(modules.size() + 1),
|
2018-05-24 19:46:51 +00:00
|
|
|
StartsWith("Index out of range"));
|
|
|
|
|
REQUIRE_THROWS_WITH(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->ptr(modules.size() + 1000000),
|
2018-05-24 19:46:51 +00:00
|
|
|
StartsWith("Index out of range"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
SECTION("forward") {
|
|
|
|
|
SECTION("calling forward() on an empty sequential is disallowed") {
|
|
|
|
|
Sequential empty;
|
|
|
|
|
REQUIRE_THROWS_WITH(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
empty->forward<int>(),
|
2018-05-24 19:46:51 +00:00
|
|
|
StartsWith("Cannot call forward() on an empty Sequential"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SECTION("calling forward() on a non-empty sequential chains correctly") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Mock module used to check that forward() calls are chained in order:
// each instance asserts that it receives the value it was constructed with
// and passes on that value plus one.
struct MockModule : torch::nn::Module {
|
2018-05-24 19:46:51 +00:00
|
|
|
// `value` is the input this module expects to receive in forward().
explicit MockModule(int value) : expected(value) {}
|
|
|
|
|
// The input value forward() asserts against.
int expected;
|
|
|
|
|
// Fails the enclosing test if `value` differs from `expected`;
// otherwise returns `value + 1`.
int forward(int value) {
|
|
|
|
|
REQUIRE(value == expected);
|
|
|
|
|
return value + 1;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});
|
|
|
|
|
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` was previously not ref-counted itself, but stored `shared_ptr<AnyModule>` to get reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself and was therefore accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->forward<int>(1) == 4);
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SECTION("calling forward() with the wrong return type throws") {
|
2018-06-25 02:03:39 +00:00
|
|
|
// Stateless test module whose forward() takes no input and always
// returns the int 5.
struct M : public torch::nn::Module {
|
2018-05-24 19:46:51 +00:00
|
|
|
// Return type is int; the enclosing section relies on this when it
// requests forward<int>().
int forward() {
|
|
|
|
|
return 5;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Sequential sequential(M{});
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->forward<int>() == 5);
|
2018-05-24 19:46:51 +00:00
|
|
|
REQUIRE_THROWS_WITH(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->forward<float>(),
|
2018-05-24 19:46:51 +00:00
|
|
|
StartsWith("The type of the return value "
|
|
|
|
|
"is int, but you asked for type float"));
|
|
|
|
|
}
|
|
|
|
|
|
2018-06-25 02:03:39 +00:00
|
|
|
SECTION("The return type of forward() defaults to Tensor") {
|
|
|
|
|
struct M : public torch::nn::Module {
|
|
|
|
|
torch::Tensor forward(torch::Tensor v) {
|
2018-05-24 19:46:51 +00:00
|
|
|
return v;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Sequential sequential(M{});
|
2018-06-27 21:34:06 +00:00
|
|
|
auto variable = torch::ones({3, 3}, torch::requires_grad());
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->forward(variable).equal(variable));
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SECTION("returns the last value") {
|
2018-06-28 03:00:53 +00:00
|
|
|
torch::manual_seed(0);
|
2018-06-19 02:45:53 +00:00
|
|
|
Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-06-27 21:34:06 +00:00
|
|
|
auto x = torch::randn({1000, 10}, torch::requires_grad());
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
auto y = sequential->forward(x);
|
2018-05-24 19:46:51 +00:00
|
|
|
REQUIRE(y.ndimension() == 2);
|
|
|
|
|
REQUIRE(y.size(0) == 1000);
|
|
|
|
|
REQUIRE(y.size(1) == 100);
|
|
|
|
|
}
|
2018-06-26 20:23:16 +00:00
|
|
|
|
|
|
|
|
SECTION("can hold other important modules") {
  // Construction-only check: no forward() call is made, so this section
  // passes as long as building the Sequential from these module kinds
  // compiles and does not throw.
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));
}
|
2018-06-28 13:30:36 +00:00
|
|
|
|
|
|
|
|
SECTION("converts at::Tensor to torch::Tensor correctly") {
|
|
|
|
|
struct M : torch::nn::Module {
|
|
|
|
|
torch::Tensor forward(torch::Tensor input) {
|
|
|
|
|
return input;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Sequential sequential(M{});
|
|
|
|
|
torch::Tensor variable = torch::ones(5);
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(sequential->forward(variable).sum().toCFloat() == 5);
|
2018-06-28 13:30:36 +00:00
|
|
|
|
|
|
|
|
at::Tensor tensor_that_is_actually_a_variable = variable * 2;
|
|
|
|
|
REQUIRE(
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
sequential->forward(tensor_that_is_actually_a_variable)
|
2018-06-28 13:30:36 +00:00
|
|
|
.sum()
|
|
|
|
|
.toCFloat() == 10);
|
|
|
|
|
}
|
2018-07-03 02:34:45 +00:00
|
|
|
SECTION("extend() pushes modules from other Sequential") {
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
struct A : torch::nn::Module {
|
|
|
|
|
int forward(int x) {
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
struct B : torch::nn::Module {
|
|
|
|
|
int forward(int x) {
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
struct C : torch::nn::Module {
|
|
|
|
|
int forward(int x) {
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
struct D : torch::nn::Module {
|
|
|
|
|
int forward(int x) {
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
};
|
2018-07-03 02:34:45 +00:00
|
|
|
Sequential a(A{}, B{});
|
|
|
|
|
Sequential b(C{}, D{});
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
a->extend(*b);
|
2018-07-03 02:34:45 +00:00
|
|
|
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(a->size() == 4);
|
2018-07-06 17:55:18 +00:00
|
|
|
REQUIRE(a[0]->as<A>());
|
|
|
|
|
REQUIRE(a[1]->as<B>());
|
|
|
|
|
REQUIRE(a[2]->as<C>());
|
|
|
|
|
REQUIRE(a[3]->as<D>());
|
2018-07-03 02:34:45 +00:00
|
|
|
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(b->size() == 2);
|
2018-07-06 17:55:18 +00:00
|
|
|
REQUIRE(b[0]->as<C>());
|
|
|
|
|
REQUIRE(b[1]->as<D>());
|
2018-07-03 02:34:45 +00:00
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<A>> c = {std::make_shared<A>(),
|
|
|
|
|
std::make_shared<A>()};
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
b->extend(c);
|
2018-07-03 02:34:45 +00:00
|
|
|
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
REQUIRE(b->size() == 4);
|
2018-07-06 17:55:18 +00:00
|
|
|
REQUIRE(b[0]->as<C>());
|
|
|
|
|
REQUIRE(b[1]->as<D>());
|
|
|
|
|
REQUIRE(b[2]->as<A>());
|
|
|
|
|
REQUIRE(b[3]->as<A>());
|
2018-07-03 02:34:45 +00:00
|
|
|
}
|
Make Sequential ref-counted (#9151)
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.
This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.
One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.
ezyang ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9151
Reviewed By: ezyang
Differential Revision: D8809298
Pulled By: goldsborough
fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
2018-07-12 00:15:08 +00:00
|
|
|
SECTION("has reference semantics") {
  Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  Sequential second(first);

  // Both holders must refer to the very same underlying object.
  REQUIRE(first.get() == second.get());
  REQUIRE(first->size() == second->size());

  // Compares element addresses rather than values: each pair of entries has
  // to be the same AnyModule object.
  const auto is_same_object = [](const AnyModule& lhs, const AnyModule& rhs) {
    return &lhs == &rhs;
  };
  REQUIRE(std::equal(
      first->begin(), first->end(), second->begin(), is_same_object));
}
|
2018-07-17 04:43:40 +00:00
|
|
|
SECTION("Is cloneable") {
  // forward() is never called in this section, so only construction and
  // cloning of the three modules is exercised.
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
  REQUIRE(sequential->size() == clone->size());

  for (size_t index = 0; index < sequential->size(); ++index) {
    // Corresponding modules report the same name ...
    REQUIRE(sequential[index]->name() == clone[index]->name());
    // ... yet are different objects.
    REQUIRE(sequential[index] != clone[index]);
  }

  // The clone must be deep: parameters are copied, not shared.
  auto original_params = sequential->parameters();
  auto cloned_params = clone->parameters();
  REQUIRE(original_params.size() == cloned_params.size());

  for (auto& param : original_params) {
    auto& counterpart = cloned_params[param.key];
    REQUIRE(!pointer_equal(param.value, counterpart));
    REQUIRE(param->device() == counterpart.device());
    REQUIRE(param->allclose(counterpart));
    // Mutate the original in place; the second loop checks that the clone
    // did not follow.
    param->data().add_(2);
  }

  for (auto& param : original_params) {
    REQUIRE(!param->allclose(cloned_params[param.key]));
  }
}
|
2018-05-24 19:46:51 +00:00
|
|
|
}
|
2018-07-23 21:49:18 +00:00
|
|
|
|
|
|
|
|
TEST_CASE("sequential/clone-to-device", "[cuda]") {
  // Clones a Sequential onto CUDA device 0 and checks that every parameter
  // and buffer of the clone reports that device.
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
  torch::Device target_device(torch::kCUDA, 0);
  Sequential clone = std::dynamic_pointer_cast<SequentialImpl>(
      sequential->clone(target_device));

  for (const auto& parameter : clone->parameters()) {
    REQUIRE(parameter->device() == target_device);
  }

  for (const auto& buffer : clone->buffers()) {
    REQUIRE(buffer->device() == target_device);
  }
}
|