#include #include #include #include #include #include #include #include #include namespace { using at::Tensor; using c10::IValue; struct MyString : public c10::intrusive_ptr_target, public std::string { using std::string::string; }; template class MaybeOwnedTest : public ::testing::Test { public: T borrowFrom; T ownCopy; T ownCopy2; c10::MaybeOwned borrowed; c10::MaybeOwned owned; c10::MaybeOwned owned2; protected: void SetUp() override; // defined below helpers void TearDown() override { // Release everything to try to trigger ASAN violations in the // test that broke things. borrowFrom = T(); ownCopy = T(); ownCopy2 = T(); borrowed = c10::MaybeOwned(); owned = c10::MaybeOwned(); owned2 = c10::MaybeOwned(); } }; //////////////////// Helpers that differ per tested type. //////////////////// template T getSampleValue(); template T getSampleValue2(); template void assertBorrow(const c10::MaybeOwned&, const T&); template void assertOwn(const c10::MaybeOwned&, const T&, size_t useCount = 2); ////////////////// Helper implementations for intrusive_ptr. ////////////////// template<> c10::intrusive_ptr getSampleValue() { return c10::make_intrusive("hello"); } template<> c10::intrusive_ptr getSampleValue2() { return c10::make_intrusive("goodbye"); } bool are_equal(const c10::intrusive_ptr& lhs, const c10::intrusive_ptr& rhs) { if (!lhs || !rhs) { return !lhs && !rhs; } return *lhs == *rhs; } template <> void assertBorrow( const c10::MaybeOwned>& mo, const c10::intrusive_ptr& borrowedFrom) { EXPECT_EQ(*mo, borrowedFrom); EXPECT_EQ(mo->get(), borrowedFrom.get()); EXPECT_EQ(borrowedFrom.use_count(), 1); } template <> void assertOwn( const c10::MaybeOwned>& mo, const c10::intrusive_ptr& original, size_t useCount) { EXPECT_EQ(*mo, original); EXPECT_EQ(mo->get(), original.get()); EXPECT_NE(&*mo, &original); EXPECT_EQ(original.use_count(), useCount); } //////////////////// Helper implementations for Tensor. //////////////////// template<> Tensor getSampleValue() { return at::zeros({2, 2}).to(at::kCPU); } template<> Tensor getSampleValue2() { return at::native::ones({2, 2}).to(at::kCPU); } bool are_equal(const Tensor& lhs, const Tensor& rhs) { if (!lhs.defined() || !rhs.defined()) { return !lhs.defined() && !rhs.defined(); } return at::native::cpu_equal(lhs, rhs); } template <> void assertBorrow( const c10::MaybeOwned& mo, const Tensor& borrowedFrom) { EXPECT_TRUE(mo->is_same(borrowedFrom)); EXPECT_EQ(borrowedFrom.use_count(), 1); } template <> void assertOwn( const c10::MaybeOwned& mo, const Tensor& original, size_t useCount) { EXPECT_TRUE(mo->is_same(original)); EXPECT_EQ(original.use_count(), useCount); } //////////////////// Helper implementations for IValue. //////////////////// template<> IValue getSampleValue() { return IValue(getSampleValue()); } template<> IValue getSampleValue2() { return IValue("hello"); } bool are_equal(const IValue& lhs, const IValue& rhs) { if (lhs.isTensor() != rhs.isTensor()) { return false; } if (lhs.isTensor() && rhs.isTensor()) { return lhs.toTensor().equal(rhs.toTensor()); } return lhs == rhs; } template <> void assertBorrow( const c10::MaybeOwned& mo, const IValue& borrowedFrom) { if (!borrowedFrom.isPtrType()) { EXPECT_EQ(*mo, borrowedFrom); } else { EXPECT_EQ(mo->internalToPointer(), borrowedFrom.internalToPointer()); EXPECT_EQ(borrowedFrom.use_count(), 1); } } template <> void assertOwn( const c10::MaybeOwned& mo, const IValue& original, size_t useCount) { if (!original.isPtrType()) { EXPECT_EQ(*mo, original); } else { EXPECT_EQ(mo->internalToPointer(), original.internalToPointer()); EXPECT_EQ(original.use_count(), useCount); } } template void MaybeOwnedTest::SetUp() { borrowFrom = getSampleValue(); ownCopy = getSampleValue(); ownCopy2 = getSampleValue(); borrowed = c10::MaybeOwned::borrowed(borrowFrom); owned = c10::MaybeOwned::owned(c10::in_place, ownCopy); owned2 = c10::MaybeOwned::owned(T(ownCopy2)); } using MaybeOwnedTypes = ::testing::Types< c10::intrusive_ptr, at::Tensor, c10::IValue >; TYPED_TEST_SUITE(MaybeOwnedTest, MaybeOwnedTypes); TYPED_TEST(MaybeOwnedTest, SimpleDereferencingString) { assertBorrow(this->borrowed, this->borrowFrom); assertOwn(this->owned, this->ownCopy); assertOwn(this->owned2, this->ownCopy2); } TYPED_TEST(MaybeOwnedTest, DefaultCtor) { c10::MaybeOwned borrowed, owned; // Don't leave the fixture versions around messing up reference counts. this->borrowed = c10::MaybeOwned(); this->owned = c10::MaybeOwned(); borrowed = c10::MaybeOwned::borrowed(this->borrowFrom); owned = c10::MaybeOwned::owned(c10::in_place, this->ownCopy); assertBorrow(borrowed, this->borrowFrom); assertOwn(owned, this->ownCopy); } TYPED_TEST(MaybeOwnedTest, CopyConstructor) { auto copiedBorrowed(this->borrowed); auto copiedOwned(this->owned); auto copiedOwned2(this->owned2); assertBorrow(this->borrowed, this->borrowFrom); assertBorrow(copiedBorrowed, this->borrowFrom); assertOwn(this->owned, this->ownCopy, 3); assertOwn(copiedOwned, this->ownCopy, 3); assertOwn(this->owned2, this->ownCopy2, 3); assertOwn(copiedOwned2, this->ownCopy2, 3); } TYPED_TEST(MaybeOwnedTest, MoveDereferencing) { // Need a different value. this->owned = c10::MaybeOwned::owned(c10::in_place, getSampleValue2()); EXPECT_TRUE(are_equal(*std::move(this->borrowed), getSampleValue())); EXPECT_TRUE(are_equal(*std::move(this->owned), getSampleValue2())); // Borrowed is unaffected. assertBorrow(this->borrowed, this->borrowFrom); // Owned is a null c10::intrusive_ptr / empty Tensor. EXPECT_TRUE(are_equal(*this->owned, TypeParam())); } TYPED_TEST(MaybeOwnedTest, MoveConstructor) { auto movedBorrowed(std::move(this->borrowed)); auto movedOwned(std::move(this->owned)); auto movedOwned2(std::move(this->owned2)); assertBorrow(movedBorrowed, this->borrowFrom); assertOwn(movedOwned, this->ownCopy); assertOwn(movedOwned2, this->ownCopy2); } TYPED_TEST(MaybeOwnedTest, CopyAssignmentIntoOwned) { auto copiedBorrowed = c10::MaybeOwned::owned(c10::in_place); auto copiedOwned = c10::MaybeOwned::owned(c10::in_place); auto copiedOwned2 = c10::MaybeOwned::owned(c10::in_place); copiedBorrowed = this->borrowed; copiedOwned = this->owned; copiedOwned2 = this->owned2; assertBorrow(this->borrowed, this->borrowFrom); assertBorrow(copiedBorrowed, this->borrowFrom); assertOwn(this->owned, this->ownCopy, 3); assertOwn(copiedOwned, this->ownCopy, 3); assertOwn(this->owned2, this->ownCopy2, 3); assertOwn(copiedOwned2, this->ownCopy2, 3); } TYPED_TEST(MaybeOwnedTest, CopyAssignmentIntoBorrowed) { auto otherBorrowFrom = getSampleValue2(); auto otherOwnCopy = getSampleValue2(); auto copiedBorrowed = c10::MaybeOwned::borrowed(otherBorrowFrom); auto copiedOwned = c10::MaybeOwned::borrowed(otherOwnCopy); auto copiedOwned2 = c10::MaybeOwned::borrowed(otherOwnCopy); copiedBorrowed = this->borrowed; copiedOwned = this->owned; copiedOwned2 = this->owned2; assertBorrow(this->borrowed, this->borrowFrom); assertBorrow(copiedBorrowed, this->borrowFrom); assertOwn(this->owned, this->ownCopy, 3); assertOwn(this->owned2, this->ownCopy2, 3); assertOwn(copiedOwned, this->ownCopy, 3); assertOwn(copiedOwned2, this->ownCopy2, 3); } TYPED_TEST(MaybeOwnedTest, MoveAssignmentIntoOwned) { auto movedBorrowed = c10::MaybeOwned::owned(c10::in_place); auto movedOwned = c10::MaybeOwned::owned(c10::in_place); auto movedOwned2 = c10::MaybeOwned::owned(c10::in_place); movedBorrowed = std::move(this->borrowed); movedOwned = std::move(this->owned); movedOwned2 = std::move(this->owned2); assertBorrow(movedBorrowed, this->borrowFrom); assertOwn(movedOwned, this->ownCopy); assertOwn(movedOwned2, this->ownCopy2); } TYPED_TEST(MaybeOwnedTest, MoveAssignmentIntoBorrowed) { auto y = getSampleValue2(); auto movedBorrowed = c10::MaybeOwned::borrowed(y); auto movedOwned = c10::MaybeOwned::borrowed(y); auto movedOwned2 = c10::MaybeOwned::borrowed(y); movedBorrowed = std::move(this->borrowed); movedOwned = std::move(this->owned); movedOwned2 = std::move(this->owned2); assertBorrow(movedBorrowed, this->borrowFrom); assertOwn(movedOwned, this->ownCopy); assertOwn(movedOwned2, this->ownCopy2); } TYPED_TEST(MaybeOwnedTest, SelfAssignment) { this->borrowed = this->borrowed; this->owned = this->owned; this->owned2 = this->owned2; assertBorrow(this->borrowed, this->borrowFrom); assertOwn(this->owned, this->ownCopy); assertOwn(this->owned2, this->ownCopy2); } } // namespace