2019-11-21 20:58:47 +00:00
|
|
|
#include <test/cpp/jit/test_base.h>
|
|
|
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
|
#include <torch/torch.h>
|
|
|
|
|
|
|
|
|
|
namespace torch {
|
|
|
|
|
namespace jit {
|
|
|
|
|
|
|
|
|
|
using namespace torch::jit::script;
|
|
|
|
|
|
|
|
|
|
void testModuleCloneInstance() {
|
|
|
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
|
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
|
|
|
auto attr_name = "attr";
|
|
|
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
|
|
|
Module m(cu, cls);
|
|
|
|
|
auto v = IValue(2);
|
|
|
|
|
m.register_attribute(attr_name,
|
|
|
|
|
IntType::get(),
|
|
|
|
|
v,
|
|
|
|
|
false);
|
|
|
|
|
|
|
|
|
|
Module m2 = m.clone();
|
|
|
|
|
Module m3 = m.clone_instance();
|
|
|
|
|
// Make sure copy works
|
|
|
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
|
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
|
|
|
|
|
|
|
|
|
|
// clone will copy both type and data, therefore we'll have a
|
|
|
|
|
// different type
|
|
|
|
|
ASSERT_NE(m.type(), m2.type());
|
|
|
|
|
// clone_instance only copies data, type is shared
|
|
|
|
|
ASSERT_EQ(m.type(), m3.type());
|
|
|
|
|
|
|
|
|
|
// change value of copied instance
|
|
|
|
|
m3.register_attribute(attr_name,
|
|
|
|
|
IntType::get(),
|
|
|
|
|
IValue(3),
|
|
|
|
|
false);
|
|
|
|
|
// Verify value of original instance doesn't change
|
|
|
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
|
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
|
|
|
|
|
}
|
|
|
|
|
|
2020-01-04 19:07:37 +00:00
|
|
|
void testModuleConstant() {
|
|
|
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
|
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
|
|
|
auto attr_name = "attr";
|
|
|
|
|
auto const_name = "const";
|
|
|
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
|
|
|
cls->addConstant(const_name, IValue(3));
|
|
|
|
|
Module m(cu, cls);
|
|
|
|
|
auto v = IValue(2);
|
|
|
|
|
m.register_attribute(attr_name,
|
|
|
|
|
IntType::get(),
|
|
|
|
|
v,
|
|
|
|
|
false);
|
|
|
|
|
ASSERT_TRUE(m.hasattr(attr_name));
|
|
|
|
|
ASSERT_TRUE(m.hasattr(const_name));
|
|
|
|
|
ASSERT_EQ(m.attr(attr_name).toInt(), 2);
|
|
|
|
|
ASSERT_EQ(m.attr(const_name).toInt(), 3);
|
|
|
|
|
}
|
|
|
|
|
|
2019-11-21 20:58:47 +00:00
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace torch
|