pytorch/test/cpp/api/sequential.cpp

423 lines
12 KiB
C++
Raw Normal View History

#include <gtest/gtest.h>
Re-organize C++ API `torch::nn` folder structure (#26262) Summary: This PR aims to re-organize C++ API `torch::nn` folder structure in the following way: - Every module in `torch/csrc/api/include/torch/nn/modules/` (except `any.h`, `named_any.h`, `modulelist.h`, `sequential.h`, `embedding.h`) has a strictly equivalent Python file in `torch/nn/modules/`. For example: `torch/csrc/api/include/torch/nn/modules/pooling.h` -> `torch/nn/modules/pooling.py` `torch/csrc/api/include/torch/nn/modules/conv.h` -> `torch/nn/modules/conv.py` `torch/csrc/api/include/torch/nn/modules/batchnorm.h` -> `torch/nn/modules/batchnorm.py` `torch/csrc/api/include/torch/nn/modules/sparse.h` -> `torch/nn/modules/sparse.py` - Containers such as `any.h`, `named_any.h`, `modulelist.h`, `sequential.h` are moved into `torch/csrc/api/include/torch/nn/modules/container/`, because their implementations are too long to be combined into one file (like `torch/nn/modules/container.py` in Python API) - `embedding.h` is not renamed to `sparse.h` yet, because we have another work stream that works on API parity for Embedding and EmbeddingBag, and renaming the file would cause conflict. After the embedding API parity work is done, we will rename `embedding.h` to `sparse.h` to match the Python file name, and move the embedding options out to options/ folder. - `torch/csrc/api/include/torch/nn/functional/` is added, and the folder structure mirrors that of `torch/csrc/api/include/torch/nn/modules/`. For example, `torch/csrc/api/include/torch/nn/functional/pooling.h` contains the functions for pooling, which are then used by the pooling modules in `torch/csrc/api/include/torch/nn/modules/pooling.h`. - `torch/csrc/api/include/torch/nn/options/` is added, and the folder structure mirrors that of `torch/csrc/api/include/torch/nn/modules/`. 
For example, `torch/csrc/api/include/torch/nn/options/pooling.h` contains MaxPoolOptions, which is used by both MaxPool modules in `torch/csrc/api/include/torch/nn/modules/pooling.h`, and max_pool functions in `torch/csrc/api/include/torch/nn/functional/pooling.h`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/26262 Differential Revision: D17422426 Pulled By: yf225 fbshipit-source-id: c413d2a374ba716dac81db31516619bbd879db7f
2019-09-17 17:05:11 +00:00
#include <torch/torch.h>
#include <algorithm>
#include <memory>
#include <vector>
#include <test/cpp/api/support.h>
using namespace torch::nn;
using namespace torch::test;
// Shared gtest fixture for every `Sequential` test in this file. All behavior
// comes from `torch::test::SeedingFixture` (declared in support.h) —
// presumably RNG seeding per test; see that header to confirm.
struct SequentialTest : public torch::test::SeedingFixture {};
TEST_F(SequentialTest, CanContainThings) {
  // Smoke test: a Sequential can be built from a mix of standard modules.
  // Success means the construction compiles and does not throw.
  Sequential seq(Linear(3, 4), ReLU(), BatchNorm(3));
}
TEST_F(SequentialTest, ConstructsFromSharedPointer) {
  // Minimal module whose forward() returns the value it was constructed with.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };

  // Unnamed construction from a pack of shared_ptrs.
  Sequential sequential(
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
  ASSERT_EQ(sequential->size(), 3);

  // Named construction via modules_ordered_dict; keys may be given either as
  // string literals or as std::string.
  Sequential sequential_named(modules_ordered_dict({
      {"m1", std::make_shared<M>(1)},
      {std::string("m2"), std::make_shared<M>(2)},
      {"m3", std::make_shared<M>(3)}
  }));
  // Check the named container itself, not the unnamed one built above.
  ASSERT_EQ(sequential_named->size(), 3);
}
TEST_F(SequentialTest, ConstructsFromConcreteType) {
  // Counts copy-constructions of M so the test can pin how often Sequential
  // copies modules passed by value.
  static int copy_count;
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // Copy `value` too, so a copied module's forward() returns a defined
    // result instead of reading an uninitialized member.
    M(const M& other) : torch::nn::Module(other), value(other.value) {
      copy_count++;
    }
    int value;
    int forward() {
      return value;
    }
  };

  copy_count = 0;
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly once,
  // which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of `M`.
  ASSERT_EQ(copy_count, 3);

  copy_count = 0;
  Sequential sequential_named(modules_ordered_dict({
      {"m1", M(1)},
      {std::string("m2"), M(2)},
      {"m3", M(3)}
  }));
  // Check the named container itself, not the unnamed one built above.
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(copy_count, 3);
}
TEST_F(SequentialTest, ConstructsFromModuleHolder) {
  // Implementation type wrapped by the ModuleHolder `M` below.
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  // Hand-written holder (the equivalent of what TORCH_MODULE would generate).
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  // Unnamed construction from holders.
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);

  // Named construction from holders.
  Sequential sequential_named(modules_ordered_dict({
      {"m1", M(1)},
      {std::string("m2"), M(2)},
      {"m3", M(3)}
  }));
  // Check the named container itself, not the unnamed one built above.
  ASSERT_EQ(sequential_named->size(), 3);
}
TEST_F(SequentialTest, PushBackAddsAnElement) {
  // Trivial module: forward() echoes the constructor argument.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  // push_back without names: the size grows by one per call, whether the
  // argument is a ModuleHolder, a shared_ptr, or a concrete module value.
  Sequential unnamed;
  ASSERT_EQ(unnamed->size(), 0);
  ASSERT_TRUE(unnamed->is_empty());
  unnamed->push_back(Linear(3, 4));
  ASSERT_EQ(unnamed->size(), 1);
  unnamed->push_back(std::make_shared<M>(1));
  ASSERT_EQ(unnamed->size(), 2);
  unnamed->push_back(M(2));
  ASSERT_EQ(unnamed->size(), 3);

  // Interleaved named and unnamed push_back: unnamed entries get their
  // position as key ("0", "3", "4"); named entries keep the given key.
  Sequential mixed;
  ASSERT_EQ(mixed->size(), 0);
  ASSERT_TRUE(mixed->is_empty());

  mixed->push_back(Linear(3, 4));
  ASSERT_EQ(mixed->size(), 1);
  ASSERT_EQ(mixed->named_children()[0].key(), "0");

  mixed->push_back(std::string("linear2"), Linear(3, 4));
  ASSERT_EQ(mixed->size(), 2);
  ASSERT_EQ(mixed->named_children()[1].key(), "linear2");

  mixed->push_back("shared_m1", std::make_shared<M>(1));
  ASSERT_EQ(mixed->size(), 3);
  ASSERT_EQ(mixed->named_children()[2].key(), "shared_m1");

  mixed->push_back(std::make_shared<M>(1));
  ASSERT_EQ(mixed->size(), 4);
  ASSERT_EQ(mixed->named_children()[3].key(), "3");

  mixed->push_back(M(1));
  ASSERT_EQ(mixed->size(), 5);
  ASSERT_EQ(mixed->named_children()[4].key(), "4");

  mixed->push_back(std::string("m2"), M(1));
  ASSERT_EQ(mixed->size(), 6);
  ASSERT_EQ(mixed->named_children()[5].key(), "m2");
}
TEST_F(SequentialTest, AccessWithAt) {
  // Trivial module: forward() echoes the constructor argument.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(&sequential->at<M>(i), modules[i].get());
  }

  // throws for a bad index, starting at the exact boundary index == size()
  // (the classic off-by-one case) and continuing well beyond it
  ASSERT_THROWS_WITH(sequential->at<M>(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1000000), "Index out of range");
}
TEST_F(SequentialTest, AccessWithPtr) {
  // Trivial module: forward() echoes the constructor argument.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index, via ptr(), operator[] and
  // the typed ptr<T>()
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
    ASSERT_EQ(sequential[i].get(), modules[i].get());
    ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index, starting at the exact boundary index == size()
  // (the classic off-by-one case) and continuing well beyond it
  ASSERT_THROWS_WITH(sequential->ptr(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->ptr(modules.size() + 1000000), "Index out of range");
}
TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  // A default-constructed Sequential holds no modules, so forward() must
  // throw instead of producing a value.
  Sequential no_modules;
  ASSERT_THROWS_WITH(
      no_modules->forward<int>(),
      "Cannot call forward() on an empty Sequential");
}
TEST_F(SequentialTest, CallingForwardChainsCorrectly) {
  // Each module checks that it receives `expected` and passes on value + 1.
  // This verifies both the ordering and the threading of outputs into inputs.
  struct MockModule : torch::nn::Module {
    explicit MockModule(int value) : expected(value) {}
    int expected;
    int forward(int value) {
      // EXPECT_EQ (not C `assert`): `assert` is compiled out when NDEBUG is
      // defined, which would silently disable this check in release builds.
      // EXPECT_* (unlike ASSERT_*) is usable in a non-void function.
      EXPECT_EQ(value, expected);
      return value + 1;
    }
  };
  Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});
  ASSERT_EQ(sequential->forward<int>(1), 4);
}
TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
  // Module whose forward() produces an int.
  struct M : public torch::nn::Module {
    int forward() {
      return 5;
    }
  };
  Sequential seq(M{});

  // Asking for the true return type succeeds ...
  ASSERT_EQ(seq->forward<int>(), 5);

  // ... while asking for a mismatched type throws a descriptive error.
  ASSERT_THROWS_WITH(
      seq->forward<float>(),
      "The type of the return value is int, but you asked for type float");
}
TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) {
  // Identity module over tensors.
  struct M : public torch::nn::Module {
    torch::Tensor forward(torch::Tensor v) {
      return v;
    }
  };
  Sequential seq(M{});

  auto input = torch::ones({3, 3}, torch::requires_grad());
  // forward() without an explicit template argument yields a torch::Tensor.
  const auto output = seq->forward(input);
  ASSERT_TRUE(output.equal(input));
}
TEST_F(SequentialTest, ForwardReturnsTheLastValue) {
  torch::manual_seed(0);
  // Feature widths go 10 -> 3 -> 5 -> 100, so the output's trailing dimension
  // must come from the final Linear layer.
  Sequential seq(Linear(10, 3), Linear(3, 5), Linear(5, 100));

  auto input = torch::randn({1000, 10}, torch::requires_grad());
  auto output = seq->forward(input);

  ASSERT_EQ(output.ndimension(), 2);
  ASSERT_EQ(output.size(0), 1000);
  ASSERT_EQ(output.size(1), 100);
}
TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  // Compile/construct-only check: a Sequential can hold each of these
  // standard module types side by side.
  Sequential seq(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));
}
TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) {
  // Four distinct identity module types, so each slot's type can be checked.
  struct A : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct B : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct C : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct D : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };

  Sequential ab(A{}, B{});
  Sequential cd(C{}, D{});

  // Extending from another Sequential appends its modules in order ...
  ab->extend(*cd);
  ASSERT_EQ(ab->size(), 4);
  ASSERT_TRUE(ab[0]->as<A>());
  ASSERT_TRUE(ab[1]->as<B>());
  ASSERT_TRUE(ab[2]->as<C>());
  ASSERT_TRUE(ab[3]->as<D>());

  // ... and leaves the source Sequential unchanged.
  ASSERT_EQ(cd->size(), 2);
  ASSERT_TRUE(cd[0]->as<C>());
  ASSERT_TRUE(cd[1]->as<D>());

  // Extending from a vector of shared_ptrs appends those modules as well.
  std::vector<std::shared_ptr<A>> more_as = {
      std::make_shared<A>(), std::make_shared<A>()};
  cd->extend(more_as);
  ASSERT_EQ(cd->size(), 4);
  ASSERT_TRUE(cd[0]->as<C>());
  ASSERT_TRUE(cd[1]->as<D>());
  ASSERT_TRUE(cd[2]->as<A>());
  ASSERT_TRUE(cd[3]->as<A>());
}
TEST_F(SequentialTest, HasReferenceSemantics) {
  Sequential original(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  // Copy-constructing the holder copies the handle, not the SequentialImpl.
  Sequential alias(original);

  ASSERT_EQ(original.get(), alias.get());
  ASSERT_EQ(original->size(), alias->size());

  // Element by element, both handles iterate over the very same AnyModule
  // objects (address identity, not mere equality).
  const bool same_elements = std::equal(
      original->begin(),
      original->end(),
      alias->begin(),
      [](const AnyModule& lhs, const AnyModule& rhs) { return &lhs == &rhs; });
  ASSERT_TRUE(same_elements);
}
TEST_F(SequentialTest, IsCloneable) {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
Sequential clone =
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
ASSERT_EQ(sequential->size(), clone->size());
for (size_t i = 0; i < sequential->size(); ++i) {
// The modules should be the same kind (type).
ASSERT_EQ(sequential[i]->name(), clone[i]->name());
// But not pointer-equal (distinct objects).
ASSERT_NE(sequential[i], clone[i]);
}
// Verify that the clone is deep, i.e. parameters of modules are cloned too.
torch::NoGradGuard no_grad;
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
auto params1 = sequential->named_parameters();
auto params2 = clone->named_parameters();
ASSERT_EQ(params1.size(), params2.size());
for (auto& param : params1) {
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
ASSERT_EQ(param->device(), params2[param.key()].device());
ASSERT_TRUE(param->allclose(params2[param.key()]));
param->add_(2);
}
for (auto& param : params1) {
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
ASSERT_FALSE(param->allclose(params2[param.key()]));
}
}
TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
auto modules = sequential->children();
ASSERT_TRUE(modules[0]->as<Linear>());
ASSERT_TRUE(modules[1]->as<Conv2d>());
ASSERT_TRUE(modules[2]->as<FeatureDropout>());
}
TEST_F(SequentialTest, CloneToDevice_CUDA) {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
torch::Device device(torch::kCUDA, 0);
Sequential clone =
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
for (const auto& p : clone->parameters()) {
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
ASSERT_EQ(p.device(), device);
}
for (const auto& b : clone->buffers()) {
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
ASSERT_EQ(b.device(), device);
}
}
TEST_F(SequentialTest, PrettyPrintSequential) {
  // Unnamed children are printed with their positional index as the key.
  Sequential unnamed(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));
  const std::string expected_unnamed =
      "torch::nn::Sequential(\n"
      " (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
      " (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
      " (2): torch::nn::Dropout(rate=0.5)\n"
      " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
      " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
      ")";
  ASSERT_EQ(c10::str(unnamed), expected_unnamed);

  // Named children are printed with their user-supplied key.
  Sequential named(modules_ordered_dict({
      {"linear", Linear(10, 3)},
      {"conv2d", Conv2d(1, 2, 3)},
      {"dropout", Dropout(0.5)},
      {"batchnorm", BatchNorm(5)},
      {"embedding", Embedding(4, 10)},
      {"lstm", LSTM(4, 5)}
  }));
  const std::string expected_named =
      "torch::nn::Sequential(\n"
      " (linear): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
      " (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
      " (dropout): torch::nn::Dropout(rate=0.5)\n"
      " (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
      " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
      ")";
  ASSERT_EQ(c10::str(named), expected_named);
}