mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Reviewed By: ezyang Differential Revision: D19273220 fbshipit-source-id: 3dfc3388914e60611c84472e3fc529f5b5e40534
96 lines
2.8 KiB
C++
#include <ATen/ATen.h>
#include "ATen/core/ivalue.h"
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/torch.h>

#include <memory>
#include <sstream>
#include <tuple>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::script;
|
|
|
|
void testIValue() {
|
|
c10::List<int64_t> foo({3, 4, 5});
|
|
ASSERT_EQ(foo.use_count(), 1);
|
|
IValue bar{foo};
|
|
ASSERT_EQ(foo.use_count(), 2);
|
|
auto baz = bar;
|
|
ASSERT_EQ(foo.use_count(), 3);
|
|
auto foo2 = std::move(bar);
|
|
ASSERT_EQ(foo.use_count(), 3);
|
|
ASSERT_TRUE(foo2.isIntList());
|
|
ASSERT_TRUE(bar.isNone());
|
|
foo2 = IValue(4.0);
|
|
ASSERT_TRUE(foo2.isDouble());
|
|
ASSERT_EQ(foo2.toDouble(), 4.0);
|
|
ASSERT_EQ(foo.use_count(), 2);
|
|
ASSERT_TRUE(baz.toIntVector() == std::vector<int64_t>({3, 4, 5}));
|
|
|
|
auto move_it = std::move(baz).toIntList();
|
|
ASSERT_EQ(foo.use_count(), 2);
|
|
ASSERT_TRUE(baz.isNone());
|
|
IValue i(4);
|
|
ASSERT_TRUE(i.isInt());
|
|
ASSERT_EQ(i.toInt(), 4);
|
|
IValue dlist(c10::List<double>({3.5}));
|
|
ASSERT_TRUE(dlist.isDoubleList());
|
|
ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.5}));
|
|
std::move(dlist).toDoubleList();
|
|
ASSERT_TRUE(dlist.isNone());
|
|
dlist = IValue(c10::List<double>({3.4}));
|
|
ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.4}));
|
|
IValue the_list(
|
|
at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
|
|
ASSERT_EQ(foo.use_count(), 3);
|
|
ASSERT_TRUE(the_list.isTuple());
|
|
auto first = the_list.toTuple()->elements()[1];
|
|
ASSERT_EQ(first.toInt(), 4);
|
|
at::Tensor tv = at::rand({3, 4});
|
|
IValue ten(tv);
|
|
ASSERT_EQ(tv.use_count(), 2);
|
|
auto ten2 = ten;
|
|
ASSERT_EQ(tv.use_count(), 3);
|
|
ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor()));
|
|
std::move(ten2).toTensor();
|
|
ASSERT_EQ(tv.use_count(), 2);
|
|
|
|
{
|
|
std::tuple<int64_t, at::Tensor> t = std::make_tuple(123, at::randn({1}));
|
|
auto iv = IValue(t);
|
|
auto t_ = iv.to<std::tuple<int64_t, at::Tensor>>();
|
|
ASSERT_EQ(std::get<0>(t_), 123);
|
|
ASSERT_EQ(
|
|
std::get<1>(t_).item().to<float>(), std::get<1>(t).item().to<float>());
|
|
}
|
|
|
|
// unsafeRemoveAttr in ivalue::Object
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu);
|
|
cls->addAttribute("attr1", TensorType::get());
|
|
cls->addAttribute("attr2", TensorType::get());
|
|
auto obj = c10::ivalue::Object::create(
|
|
c10::StrongTypePtr(cu, cls), cls->numAttributes());
|
|
obj->unsafeRemoveAttr("attr1");
|
|
// attr1 is not removed in the type
|
|
ASSERT_TRUE(cls->hasAttribute("attr1"));
|
|
ASSERT_TRUE(cls->hasAttribute("attr2"));
|
|
ASSERT_TRUE(obj->slots().size() == 1);
|
|
|
|
// Test tuple print
|
|
{
|
|
IValue tp = std::make_tuple(3);
|
|
std::stringstream ss;
|
|
ss << tp;
|
|
ASSERT_EQ(ss.str(), "(3,)");
|
|
}
|
|
|
|
{
|
|
IValue tp = std::make_tuple(3, 3);
|
|
std::stringstream ss;
|
|
ss << tp;
|
|
ASSERT_EQ(ss.str(), "(3, 3)");
|
|
}
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|