pytorch/test/cpp/jit/test_ivalue.cpp
Michael Suo dfdb86a595 big cpp test reorg (#24801)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801

This fixes the ODR violations in fbcode static builds, which have been broken for several months.

This PR is unfortunately quite large, but the changes are only mechanical:
1. Tests defined in header files -> tests defined in cpp files
2. Remove the `torch::jit::testing` namespace -> `torch::jit`.
3. Single `test.h` file that aggregates all tests.
4. Separate out files for gtest and python versions of the tests instead of using a build flag
5. Add a README explaining how to add a new test and why the cpp tests are structured the way they are.

Test Plan: Imported from OSS

Differential Revision: D16878605

Pulled By: suo

fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
2019-08-18 16:49:56 -07:00

58 lines
1.7 KiB
C++

#include <ATen/ATen.h>
#include "ATen/core/ivalue.h"
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
namespace torch {
namespace jit {
using namespace torch::autograd;

// Exercises IValue construction, copying, moving, and typed accessors for
// int lists, doubles, ints, double lists, tuples, and tensors.
//
// The use_count() checks verify three things:
// - copying an IValue adds an owner of the underlying payload;
// - moving an IValue transfers ownership without changing the count;
// - an rvalue accessor such as `std::move(v).toIntList()` leaves `v` as None.
void testIValue() {
  // --- Int list: a copy shares the payload, a move transfers it. ---
  c10::List<int64_t> foo({3, 4, 5});
  ASSERT_EQ(foo.use_count(), 1);
  IValue bar{foo};
  ASSERT_EQ(foo.use_count(), 2);
  auto baz = bar; // copy: one more owner
  ASSERT_EQ(foo.use_count(), 3);
  auto foo2 = std::move(bar); // move: ownership transferred, count unchanged
  ASSERT_EQ(foo.use_count(), 3);
  ASSERT_TRUE(foo2.isIntList());
  ASSERT_TRUE(bar.isNone()); // a moved-from IValue is None

  // Reassigning foo2 drops its reference to the list.
  foo2 = IValue(4.0);
  ASSERT_TRUE(foo2.isDouble());
  ASSERT_EQ(foo2.toDouble(), 4.0);
  ASSERT_EQ(foo.use_count(), 2);

  ASSERT_TRUE(baz.toIntListRef().equals({3, 4, 5}));
  // Moving the list out of baz keeps the count at 2 (foo + move_it)
  // and leaves baz as None.
  auto move_it = std::move(baz).toIntList();
  ASSERT_EQ(foo.use_count(), 2);
  ASSERT_TRUE(baz.isNone());

  // --- Scalar int. ---
  IValue i(4);
  ASSERT_TRUE(i.isInt());
  ASSERT_EQ(i.toInt(), 4);

  // --- Double list. ---
  IValue dlist(c10::List<double>({3.5}));
  ASSERT_TRUE(dlist.isDoubleList());
  ASSERT_TRUE(dlist.toDoubleListRef().equals({3.5}));
  // The result is intentionally discarded: only the move-out side effect
  // on dlist is being tested.
  std::move(dlist).toDoubleList();
  ASSERT_TRUE(dlist.isNone());
  dlist = IValue(c10::List<double>({3.4}));
  ASSERT_TRUE(dlist.toDoubleListRef().equals({3.4}));

  // --- Tuple: holding `foo` adds another owner of the list. ---
  IValue the_list(
      at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
  ASSERT_EQ(foo.use_count(), 3);
  ASSERT_TRUE(the_list.isTuple());
  // Index 1 is the tuple's second element, IValue(4).
  auto second = the_list.toTuple()->elements()[1];
  ASSERT_EQ(second.toInt(), 4);

  // --- Tensor. ---
  at::Tensor tv = at::rand({3, 4});
  IValue ten(tv);
  ASSERT_EQ(tv.use_count(), 2);
  auto ten2 = ten;
  ASSERT_EQ(tv.use_count(), 3);
  ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor()));
  // The result is intentionally discarded; the moved-out tensor is a
  // temporary, so the count drops back to tv + ten.
  std::move(ten2).toTensor();
  ASSERT_EQ(tv.use_count(), 2);
  // Like the list cases above, moving the payload out leaves ten2 as None.
  ASSERT_TRUE(ten2.isNone());
}
} // namespace jit
} // namespace torch