diff --git a/c10/test/util/ArrayRef_test.cpp b/c10/test/util/ArrayRef_test.cpp new file mode 100644 index 00000000000..00e5eeab695 --- /dev/null +++ b/c10/test/util/ArrayRef_test.cpp @@ -0,0 +1,45 @@ +#include + +#include +#include + +#include +#include + +namespace { + +template +class ctor_from_container_test_span_ { + T* data_; + std::size_t sz_; + + public: + template >> + constexpr explicit ctor_from_container_test_span_( + std::conditional_t, const V, V>& vec) noexcept + : data_(vec.data()), sz_(vec.size()) {} + + [[nodiscard]] constexpr auto data() const noexcept { + return data_; + } + + [[nodiscard]] constexpr auto size() const noexcept { + return sz_; + } +}; + +TEST(ArrayRefTest, ctor_from_container_test) { + using value_type = int; + std::vector test_vec{1, 6, 32, 4, 68, 3, 7}; + const ctor_from_container_test_span_ test_mspan{test_vec}; + const ctor_from_container_test_span_ test_cspan{ + std::as_const(test_vec)}; + + const auto test_ref_mspan = c10::ArrayRef(test_mspan); + const auto test_ref_cspan = c10::ArrayRef(test_cspan); + + EXPECT_EQ(std::as_const(test_vec), test_ref_mspan); + EXPECT_EQ(std::as_const(test_vec), test_ref_cspan); +} + +} // namespace diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index bd1405c1fc6..10c83998c42 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -98,9 +98,9 @@ class ArrayRef final { template < typename Container, - typename = std::enable_if_t().data())>, - T*>>> + typename U = decltype(std::declval().data()), + typename = std::enable_if_t< + (std::is_same_v || std::is_same_v)>> /* implicit */ ArrayRef(const Container& container) : Data(container.data()), Length(container.size()) { debugCheckNullptrInvariant();