2018-09-22 04:12:37 +00:00
|
|
|
#include <gtest/gtest.h>
|
2018-05-09 21:01:19 +00:00
|
|
|
|
2018-06-25 02:03:39 +00:00
|
|
|
#include <torch/nn/module.h>
|
|
|
|
|
#include <torch/nn/modules/linear.h>
|
|
|
|
|
#include <torch/nn/modules/rnn.h>
|
|
|
|
|
#include <torch/tensor.h>
|
2018-06-28 03:00:53 +00:00
|
|
|
#include <torch/utils.h>
|
2018-05-09 21:01:19 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
#include <test/cpp/api/support.h>
|
2018-07-17 04:43:40 +00:00
|
|
|
|
2018-05-09 21:01:19 +00:00
|
|
|
using namespace torch::nn;
|
2018-07-17 04:43:40 +00:00
|
|
|
using namespace torch::test;
|
2018-05-24 19:46:51 +00:00
|
|
|
|
2018-06-25 02:03:39 +00:00
|
|
|
// Minimal parameterless module; used below to exercise name() demangling
// and as<T>() casts against a type unrelated to Linear.
struct AGIUnit : torch::nn::Module {};
|
2018-05-10 15:52:38 +00:00
|
|
|
|
|
|
|
|
namespace test {
|
2018-06-25 02:03:39 +00:00
|
|
|
struct AGIUnit : torch::nn::Module {};
|
|
|
|
|
struct AGIUnit2 : torch::nn::Module {
|
|
|
|
|
AGIUnit2() : torch::nn::Module("Foo") {}
|
2018-05-10 15:52:38 +00:00
|
|
|
};
|
|
|
|
|
} // namespace test
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
// Shared fixture for all tests in this file. SeedingFixture comes from
// test/cpp/api/support.h — presumably it seeds RNG state for determinism;
// see that header for the exact behavior.
struct ModuleTest : torch::test::SeedingFixture {};
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear module(3, 4);
  // Modules start out in training mode.
  ASSERT_TRUE(module->is_training());

  // eval() switches to inference mode ...
  module->eval();
  ASSERT_FALSE(module->is_training());

  // ... and train() switches back.
  module->train();
  ASSERT_TRUE(module->is_training());
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, ZeroGrad) {
  Linear module(3, 4);
  // Renamed from `weight`: this tensor is the *input* to the module, not a
  // weight of it.
  auto input = torch::ones({8, 3}, torch::requires_grad());
  auto loss = module->forward(input).sum();
  loss.backward();

  // After backward(), every parameter has a defined, non-zero gradient.
  for (auto& parameter : module->parameters()) {
    auto grad = parameter->grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }

  module->zero_grad();

  // zero_grad() keeps the gradients defined but resets them to zero.
  for (auto& parameter : module->parameters()) {
    auto grad = parameter->grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_EQ(grad.sum().item<float>(), 0);
  }
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      // Use torch::requires_grad() (not at::) for consistency with the
      // rest of this file; the torch:: name is an alias of the at:: one.
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  // Only `x` participates in the graph, so only `x` receives a gradient.
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  module.zero_grad();

  // zero_grad() must not materialize a gradient for a parameter that never
  // received one: `y`'s gradient stays undefined, `x`'s is zeroed.
  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
}
|
|
|
|
|
|
2018-10-25 20:50:06 +00:00
|
|
|
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  // Expose the protected register_module() for testing.
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  // Expose the protected register_module() for testing.
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
  };
  TestModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  // Registering a second submodule under an existing name must fail.
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  // Expose the protected register_parameter() for testing.
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_parameter;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  // Expose the protected register_parameter() for testing.
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_parameter;
  };
  TestModel model;
  model.register_parameter("p", torch::ones(5));
  // Registering a second parameter under an existing name must fail.
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  // Expose the protected register_buffer() for testing.
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_buffer;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  // Expose the protected register_buffer() for testing.
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_buffer;
  };
  TestModel model;
  model.register_buffer("p", torch::ones(5));
  // Registering a second buffer under an existing name must fail.
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CanGetName) {
  // EXPECT (non-fatal) instead of ASSERT because demangling may fail on
  // some platforms. (Comment updated: the old "CHECK instead of REQUIRE"
  // wording referred to Catch macros this file no longer uses.)
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_TRUE(agi.name() == "AGIUnit");
  EXPECT_TRUE(agi.name() == "AGIUnit");
  // A namespaced module's name carries the qualifier; an explicitly named
  // module reports the name it was constructed with.
  EXPECT_TRUE(test::AGIUnit().name() == "test::AGIUnit");
  EXPECT_TRUE(test::AGIUnit2().name() == "Foo");
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, TestAsCastsModulesCorrectly) {
  Linear module(3, 4);
  // Through the holder: as<T>() matches the holder type, the impl type and
  // any base class, and yields nullptr for unrelated types.
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  // The same holds through a shared_ptr to the base class ...
  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  // ... and through a plain base-class reference.
  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  // A successful cast gives access to the concrete type's members.
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  // A module that is not a Linear casts to nullptr for Linear types but
  // to itself for its own type.
  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear module(128, 64);
  // Parameters start on the CPU in float32.
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter->device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter->dtype(), torch::kFloat32);
  }
  {
    // Moving between two CUDA devices updates every parameter's device.
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter->device().index(), 0);
    }
    // Changed from at::kCUDA to torch::kCUDA for consistency with the rest
    // of this test; the two names are aliases.
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter->device().index(), 1);
    }
  }
  {
    // Moving back to the CPU.
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CPU);
    }
  }
  {
    // Dtype-only conversion.
    module->to(torch::kInt32);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->dtype(), torch::kInt32);
    }
  }
  {
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->dtype(), torch::kFloat64);
    }
  }
  {
    // Device and dtype can be changed in one call.
    module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter->device().index(), 1);
    }
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->dtype(), torch::kUInt8);
    }
  }
}
|
2018-05-10 07:49:29 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  // The base Module::clone() is unimplemented; subclasses must either
  // override it or derive from Cloneable<>.
  struct UnCloneable : Module {};
  UnCloneable module;
  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
}
|
2018-05-10 07:49:29 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  // A module that supplies its own clone() override is cloneable; the
  // returned pointer's value is irrelevant here.
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        torch::optional<torch::Device> device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  ASSERT_NO_THROW({ module.clone(); });
}
|
2018-05-15 01:24:58 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  auto module = std::make_shared<TestModule>();

  torch::NoGradGuard no_grad;

  auto module2 = module->clone();
  auto params1 = module->parameters();
  auto params2 = module2->parameters();
  // Three Linears with weight + bias each = 6 parameters per module.
  ASSERT_EQ(params1.size(), 6);
  ASSERT_EQ(params2.size(), 6);
  // The clone's parameters are distinct tensors holding equal values ...
  for (auto& param : params1) {
    ASSERT_FALSE(pointer_equal(param.value, params2[param.key]));
    ASSERT_TRUE(param->allclose(params2[param.key]));
    param->add_(2);
  }
  // ... so mutating the originals leaves the clone untouched.
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key]));
  }

  // Buffers behave exactly the same way.
  auto buffers1 = module->buffers();
  auto buffers2 = module2->buffers();
  ASSERT_EQ(buffers1.size(), 1);
  ASSERT_EQ(buffers2.size(), 1);
  for (auto& buffer : buffers1) {
    ASSERT_FALSE(pointer_equal(buffer.value, buffers2[buffer.key]));
    ASSERT_TRUE(buffer->allclose(buffers2[buffer.key]));
    buffer->add_(2);
  }
  for (auto& buffer : buffers1) {
    ASSERT_FALSE(buffer->allclose(buffers2[buffer.key]));
  }
}
|
2018-05-17 21:10:15 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  // The member tensor and the registered parameter alias the same storage.
  ASSERT_TRUE(pointer_equal(module->weight, module->parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->parameters()["weight"]));

  auto module2 = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(module->clone()));
  // Cloning re-establishes that member/parameter aliasing inside the
  // clone, while keeping the clone's tensors distinct from the original's.
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(pointer_equal(module2->weight, module2->parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(pointer_equal(module2->weight, module->parameters()["weight"]));
}
|
2018-05-17 21:10:15 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    // Plain (non-tensor) state, to check it is copied across clone().
    int value = 0;
  };
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  // The nested submodule's tensor is copied into distinct storage with the
  // same values, and its non-tensor members are copied as well.
  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(
      pointer_equal(b->module->weight, b->module->parameters()["weight"]));
  ASSERT_TRUE(b->module->parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  // clone() without an explicit device keeps parameters and buffers on
  // whatever device the original lives on.
  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter->device().type(), device.type());
    ASSERT_EQ(parameter->device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer->device().type(), device.type());
    ASSERT_EQ(buffer->device().index(), device.index());
  }
}
|
|
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, CloningToAParticularDevicePlacesAllParametersThere_CUDA) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 1);
  // Everything is on the CPU here; clone(device) must move the copies.
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter->device().type(), device.type());
    ASSERT_EQ(parameter->device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer->device().type(), device.type());
    ASSERT_EQ(buffer->device().index(), device.index());
  }
}
|
2018-06-26 04:11:49 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
struct ParameterTestModule : Module {
|
|
|
|
|
ParameterTestModule() {
|
|
|
|
|
a = register_parameter("a", torch::zeros({2, 2}));
|
|
|
|
|
b = register_parameter("b", torch::ones({2, 2}));
|
|
|
|
|
c = register_parameter("c", torch::ones({2, 2}) * 2);
|
|
|
|
|
}
|
2018-06-26 04:11:49 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
torch::Tensor a, b, c;
|
|
|
|
|
};
|
2018-06-26 04:11:49 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule module;
  // One entry per register_parameter() call in the fixture.
  ASSERT_EQ(module.parameters().size(), 3);
}
|
2018-06-26 04:11:49 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule module;
  auto parameters = module.parameters();
  // Keys match the names passed to register_parameter().
  ASSERT_TRUE(parameters.contains("a"));
  ASSERT_TRUE(parameters.contains("b"));
  ASSERT_TRUE(parameters.contains("c"));
}
|
2018-06-26 04:11:49 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
struct BufferTestModule : Module {
|
|
|
|
|
BufferTestModule() {
|
|
|
|
|
a = register_buffer("a", torch::zeros({2, 2}));
|
|
|
|
|
b = register_buffer("b", torch::ones({2, 2}));
|
|
|
|
|
c = register_buffer("c", torch::ones({2, 2}) * 2);
|
2018-06-26 04:11:49 +00:00
|
|
|
}
|
2018-09-22 04:12:37 +00:00
|
|
|
|
|
|
|
|
torch::Tensor a, b, c;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule module;
  // One entry per register_buffer() call in the fixture.
  ASSERT_EQ(module.buffers().size(), 3);
}
|
2018-07-19 22:55:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule module;
  auto buffers = module.buffers();
  // Keys match the names passed to register_buffer().
  ASSERT_TRUE(buffers.contains("a"));
  ASSERT_TRUE(buffers.contains("b"));
  ASSERT_TRUE(buffers.contains("c"));
}
|
2018-07-19 22:55:51 +00:00
|
|
|
|
2018-09-22 04:12:37 +00:00
|
|
|
struct AImpl : torch::nn::Module {
|
|
|
|
|
AImpl() : x_(123) {}
|
|
|
|
|
AImpl(int x) : x_(x) {}
|
|
|
|
|
int x_;
|
|
|
|
|
};
|
|
|
|
|
TORCH_MODULE(A);
|
|
|
|
|
|
|
|
|
|
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  // Default-constructing the holder default-constructs the impl (x_ = 123).
  A a;
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 123);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  // Arguments given to the holder are forwarded to the impl constructor.
  A a(5);
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 5);
}
|
|
|
|
|
|
|
|
|
|
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  // Constructing from nullptr leaves the holder empty; dereferencing an
  // empty holder throws.
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}
|