// pytorch/test/cpp/api/any.cpp
#include <catch.hpp>
#include <torch/torch.h>
#include <torch/utils.h>
#include <torch/nn/modules/any.h>
#include <algorithm>
#include <string>
using namespace torch::nn;
using namespace torch::detail;
using Catch::Contains;
using Catch::StartsWith;
// Tests for torch::nn::AnyModule: a type-erased wrapper that stores an
// arbitrary module and lets forward() be called with arbitrary argument
// and return types, checked at runtime.
TEST_CASE("any-module") {
  torch::manual_seed(0);
  SECTION("int()") {
    // forward() taking no arguments.
    struct M : torch::nn::Module {
      int forward() {
        return 123;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward().get<int>() == 123);
  }
  SECTION("int(int)") {
    struct M : torch::nn::Module {
      int forward(int x) {
        return x;
      }
    };
    AnyModule any(M{});
    REQUIRE(any.forward(5).get<int>() == 5);
  }
  SECTION("const char*(const char*)") {
    struct M : torch::nn::Module {
      const char* forward(const char* x) {
        return x;
      }
    };
    AnyModule any(M{});
    // Compare via std::string so we check contents, not pointer identity.
    REQUIRE(any.forward("hello").get<const char*>() == std::string("hello"));
  }
  SECTION("string(int, const double)") {
    struct M : torch::nn::Module {
      std::string forward(int x, const double f) {
        return std::to_string(static_cast<int>(x + f));
      }
    };
    AnyModule any(M{});
    int x = 4;
    REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7"));
  }
  SECTION("Tensor(string, const string&, string&&)") {
    // Mixes by-value, const-lvalue-ref and rvalue-ref parameters to verify
    // AnyModule forwards each value category correctly.
    // NOTE(review): an unrelated commit-message blob had been pasted into
    // this function body by the scrape; it has been removed to restore the
    // original code.
    struct M : torch::nn::Module {
      torch::Tensor forward(
          std::string a,
          const std::string& b,
          std::string&& c) {
        const auto s = a + b + c;
        return torch::ones({static_cast<int64_t>(s.size())});
      }
    };
    AnyModule any(M{});
    // "a" + "ab" + "abc" has 6 characters -> a ones-tensor summing to 6.
    REQUIRE(
        any.forward(std::string("a"), std::string("ab"), std::string("abc"))
            .get<torch::Tensor>()
            .sum()
            .toCInt() == 6);
  }
  SECTION("wrong argument type") {
    struct M : torch::nn::Module {
      int forward(float x) {
        return x;
      }
    };
    AnyModule any(M{});
    // 5.0 is a double literal; the module expects float, so the type check
    // must reject it at runtime.
    REQUIRE_THROWS_WITH(
        any.forward(5.0),
        StartsWith("Expected argument #0 to be of type float, "
                   "but received value of type double"));
  }
  SECTION("wrong number of arguments") {
    struct M : torch::nn::Module {
      int forward(int a, int b) {
        return a + b;
      }
    };
    AnyModule any(M{});
    REQUIRE_THROWS_WITH(
        any.forward(),
        Contains("M's forward() method expects 2 arguments, but received 0"));
    REQUIRE_THROWS_WITH(
        any.forward(5),
        Contains("M's forward() method expects 2 arguments, but received 1"));
    REQUIRE_THROWS_WITH(
        any.forward(1, 2, 3),
        Contains("M's forward() method expects 2 arguments, but received 3"));
  }
  SECTION("get()") {
    struct M : torch::nn::Module {
      explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return x;
      }
    };
    AnyModule any(M{5});
    SECTION("good cast") {
      REQUIRE(any.get<M>().value == 5);
    }
    SECTION("bad cast") {
      struct N : torch::nn::Module {};
      REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
    }
  }
  SECTION("ptr()") {
    struct M : torch::nn::Module {
      explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return x;
      }
    };
    AnyModule any(M{5});
    SECTION("base class cast") {
      // ptr() with no template argument yields the base Module pointer.
      auto ptr = any.ptr();
      REQUIRE(ptr != nullptr);
      REQUIRE(ptr->name() == "M");
    }
    SECTION("good downcast") {
      auto ptr = any.ptr<M>();
      REQUIRE(ptr != nullptr);
      REQUIRE(ptr->value == 5);
    }
    SECTION("bad downcast") {
      struct N : torch::nn::Module {};
      REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
    }
  }
  SECTION("default state is empty") {
    struct M : torch::nn::Module {
      explicit M(int value_) : value(value_) {}
      int value;
      int forward(float x) {
        return x;
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    // Assigning a shared_ptr<Module> fills the AnyModule in place.
    any = std::make_shared<M>(5);
    REQUIRE(!any.is_empty());
    REQUIRE(any.get<M>().value == 5);
  }
  SECTION("all methods throw for empty AnyModule") {
    struct M : torch::nn::Module {
      int forward(int x) {
        return x;
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    REQUIRE_THROWS_WITH(
        any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.type_info(),
        StartsWith("Cannot call type_info() on an empty AnyModule"));
    REQUIRE_THROWS_WITH(
        any.forward<int>(5),
        StartsWith("Cannot call forward() on an empty AnyModule"));
  }
  SECTION("can move assign different modules") {
    // Fixed typo in section name ("differentm" -> "different").
    struct M : torch::nn::Module {
      std::string forward(int x) {
        return std::to_string(x);
      }
    };
    struct N : torch::nn::Module {
      int forward(float x) {
        return 3 + x;
      }
    };
    AnyModule any;
    REQUIRE(any.is_empty());
    any = std::make_shared<M>();
    REQUIRE(!any.is_empty());
    REQUIRE(any.forward(5).get<std::string>() == "5");
    // Re-assignment replaces the stored module and its forward() signature.
    any = std::make_shared<N>();
    REQUIRE(!any.is_empty());
    REQUIRE(any.forward(5.0f).get<int>() == 8);
  }
  SECTION("has reference semantics") {
    // Copying a Sequential (a ModuleHolder) shares the underlying modules.
    Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
    Sequential second(first);
    REQUIRE(first.size() == second.size());
    REQUIRE(std::equal(first.begin(), first.end(), second.begin()));
  }
  SECTION("constructs from ModuleHolder") {
    struct MImpl : torch::nn::Module {
      explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
      int value;
      int forward(float x) {
        return x;
      }
    };
    struct M : torch::nn::ModuleHolder<MImpl> {
      using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
      using torch::nn::ModuleHolder<MImpl>::get;
    };
    AnyModule any(M{5});
    // Retrieval works both via the Impl type and via the holder type.
    REQUIRE(any.get<MImpl>().value == 5);
    REQUIRE(any.get<M>()->value == 5);
  }
}
namespace torch {
namespace nn {
struct TestValue {
template <typename T>
explicit TestValue(T&& value) : value_(std::forward<T>(value)) {}
AnyModule::Value operator()() {
return std::move(value_);
}
AnyModule::Value value_;
};
template <typename T>
AnyModule::Value make_value(T&& value) {
return TestValue(std::forward<T>(value))();
}
} // namespace nn
} // namespace torch
// Tests for AnyModule::Value, the type-erased container AnyModule uses to
// pass arguments into and return results out of forward().
TEST_CASE("any-value") {
  torch::manual_seed(0);
  SECTION("gets the correct value for the right type") {
    SECTION("int") {
      auto v = make_value(5);
      // typeid() ignores top-level const, so T and const T both match.
      REQUIRE(v.try_get<int>() != nullptr);
      REQUIRE(v.try_get<const int>() != nullptr);
      REQUIRE(v.get<int>() == 5);
    }
    SECTION("const int") {
      auto v = make_value(5);
      REQUIRE(v.try_get<const int>() != nullptr);
      REQUIRE(v.try_get<int>() != nullptr);
      REQUIRE(v.get<const int>() == 5);
    }
    SECTION("const char*") {
      auto v = make_value("hello");
      REQUIRE(v.try_get<const char*>() != nullptr);
      REQUIRE(v.get<const char*>() == std::string("hello"));
    }
    SECTION("std::string") {
      auto v = make_value(std::string("hello"));
      REQUIRE(v.try_get<std::string>() != nullptr);
      REQUIRE(v.get<std::string>() == "hello");
    }
    SECTION("pointers") {
      std::string s("hello");
      std::string* ptr = &s;
      auto v = make_value(ptr);
      REQUIRE(v.try_get<std::string*>() != nullptr);
      REQUIRE(*v.get<std::string*>() == "hello");
    }
    SECTION("references") {
      std::string s("hello");
      const std::string& ref = s;
      // A reference argument is stored as a copy of the referenced value.
      auto v = make_value(ref);
      REQUIRE(v.try_get<std::string>() != nullptr);
      REQUIRE(v.get<std::string>() == "hello");
    }
  }
  SECTION("try_get returns nullptr for the wrong type") {
    auto v = make_value(5);
    REQUIRE(v.try_get<int>() != nullptr);
    REQUIRE(v.try_get<float>() == nullptr);
    REQUIRE(v.try_get<long>() == nullptr);
    REQUIRE(v.try_get<std::string>() == nullptr);
  }
  SECTION("get throws for the wrong type") {
    auto v = make_value(5);
    REQUIRE(v.try_get<int>() != nullptr);
    REQUIRE_THROWS_WITH(
        v.get<float>(),
        StartsWith("Attempted to cast Value to float, "
                   "but its actual type is int"));
    REQUIRE_THROWS_WITH(
        v.get<long>(),
        StartsWith("Attempted to cast Value to long, "
                   "but its actual type is int"));
  }
  SECTION("move is allowed") {
    auto v = make_value(5);
    SECTION("construction") {
      auto target = make_value(std::move(v));
      REQUIRE(target.try_get<int>() != nullptr);
      REQUIRE(target.get<int>() == 5);
    }
    SECTION("assignment") {
      auto target = make_value(10);
      target = std::move(v);
      REQUIRE(target.try_get<int>() != nullptr);
      REQUIRE(target.get<int>() == 5);
    }
  }
  SECTION("type_info is correct") {
    SECTION("int") {
      auto v = make_value(5);
      REQUIRE(v.type_info().hash_code() == typeid(int).hash_code());
    }
    SECTION("const char") {
      // A string literal decays to const char*.
      auto v = make_value("hello");
      REQUIRE(v.type_info().hash_code() == typeid(const char*).hash_code());
    }
    SECTION("std::string") {
      auto v = make_value(std::string("hello"));
      REQUIRE(v.type_info().hash_code() == typeid(std::string).hash_code());
    }
  }
}