Add compatibility with Protobuf 3.12 (#4291)

In Protobuf 3.12, classes generated from protobuf files are declared as
`final`, so use those classes as members rather than base classes.

Ref: https://github.com/protocolbuffers/protobuf/releases/tag/v3.12.0
This commit is contained in:
Chih-Hsuan Yen 2020-06-26 11:34:08 +08:00 committed by GitHub
parent 5db67ec000
commit a37e2e33b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 62 deletions

View file

@ -149,44 +149,48 @@ struct DimSetter<d, dims...> {
};
template <int... dims>
struct TensorShapeTypeProto : public TensorShapeProto {
struct TensorShapeTypeProto {
TensorShapeTypeProto() {
DimSetter<dims...>::set(*this);
DimSetter<dims...>::set(proto);
}
TensorShapeProto proto;
};
template <>
struct TensorShapeTypeProto<> : public TensorShapeProto {};
struct TensorShapeTypeProto<> { TensorShapeProto proto; };
template <TensorProto_DataType T>
struct TensorTypeProto : public TypeProto {
struct TensorTypeProto {
TensorTypeProto() {
mutable_tensor_type()->set_elem_type(T);
proto.mutable_tensor_type()->set_elem_type(T);
}
TypeProto proto;
};
template <TensorProto_DataType T>
struct SparseTensorTypeProto : public TypeProto {
struct SparseTensorTypeProto {
SparseTensorTypeProto() {
mutable_sparse_tensor_type()->set_elem_type(T);
proto.mutable_sparse_tensor_type()->set_elem_type(T);
}
void SetShape(const TensorShapeProto& shape) {
mutable_sparse_tensor_type()->mutable_shape()->CopyFrom(shape);
proto.mutable_sparse_tensor_type()->mutable_shape()->CopyFrom(shape);
}
void SetShape(TensorShapeProto&& shape) {
*mutable_sparse_tensor_type()->mutable_shape() = std::move(shape);
*proto.mutable_sparse_tensor_type()->mutable_shape() = std::move(shape);
}
void ClearShape() {
mutable_sparse_tensor_type()->clear_shape();
proto.mutable_sparse_tensor_type()->clear_shape();
}
TypeProto proto;
};
template <TensorProto_DataType key, TensorProto_DataType value>
struct MapTypeProto : public TypeProto {
struct MapTypeProto {
MapTypeProto() {
mutable_map_type()->set_key_type(key);
mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(value);
proto.mutable_map_type()->set_key_type(key);
proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(value);
}
TypeProto proto;
};
class DataTypeTest : public testing::Test {
@ -241,9 +245,9 @@ TEST_F(DataTypeTest, OpaqueRegistrationTest) {
TEST_F(DataTypeTest, MapStringStringTest) {
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
auto ml_str_str = DataTypeImpl::GetType<MapStringToString>();
EXPECT_TRUE(DataTypeImpl::GetTensorType<float>()->IsCompatible(tensor_type));
EXPECT_FALSE(DataTypeImpl::GetTensorType<uint64_t>()->IsCompatible(tensor_type));
EXPECT_FALSE(ml_str_str->IsCompatible(tensor_type));
EXPECT_TRUE(DataTypeImpl::GetTensorType<float>()->IsCompatible(tensor_type.proto));
EXPECT_FALSE(DataTypeImpl::GetTensorType<uint64_t>()->IsCompatible(tensor_type.proto));
EXPECT_FALSE(ml_str_str->IsCompatible(tensor_type.proto));
utils::ContainerChecker c_checker(ml_str_str);
bool result = c_checker.IsMapOf<std::string, std::string>();
EXPECT_TRUE(result);
@ -257,17 +261,17 @@ TEST_F(DataTypeTest, MapStringStringTest) {
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_STRING> maps2s_type;
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type));
EXPECT_FALSE(ml_str_str->IsCompatible(maps2i_type));
EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type.proto));
EXPECT_FALSE(ml_str_str->IsCompatible(maps2i_type.proto));
}
TEST_F(DataTypeTest, MapStringInt64Test) {
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_STRING> maps2s_type;
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2s_type));
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(tensor_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2s_type.proto));
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(tensor_type.proto));
utils::ContainerChecker c_checker(DataTypeImpl::GetType<MapStringToInt64>());
bool result = c_checker.IsMapOf<std::string, int64_t>();
@ -278,36 +282,36 @@ TEST_F(DataTypeTest, MapStringFloatTest) {
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_FLOAT> maps2f_type;
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2f_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(tensor_type));
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2f_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(tensor_type.proto));
}
TEST_F(DataTypeTest, MapStringDoubleTest) {
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_DOUBLE> maps2d_type;
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2d_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(tensor_type));
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2d_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(tensor_type.proto));
}
TEST_F(DataTypeTest, MapInt64StringTest) {
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_STRING> mapi2s_type;
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_INT64> mapi2i_type;
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2s_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type));
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2s_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type.proto));
}
TEST_F(DataTypeTest, MapInt64DoubleTest) {
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_DOUBLE> mapi2d_type;
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_INT64> mapi2i_type;
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToDouble>()->IsCompatible(mapi2d_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type));
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToDouble>()->IsCompatible(mapi2d_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type.proto));
}
TEST_F(DataTypeTest, RecursiveMapTest) {
@ -366,9 +370,9 @@ TEST_F(DataTypeTest, VectorMapStringToFloatTest) {
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_TRUE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(vector_map_string_to_float));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2d_type));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(tensor_type));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2d_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(tensor_type.proto));
utils::ContainerChecker c_check(DataTypeImpl::GetType<VectorMapStringToFloat>());
bool result = c_check.IsSequenceOf<MapStringToFloat>();
EXPECT_TRUE(result);
@ -384,9 +388,9 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) {
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
EXPECT_TRUE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(type_proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2d_type));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2i_type));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(tensor_type));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2d_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(tensor_type.proto));
}
TEST_F(DataTypeTest, BFloat16Test) {
@ -476,7 +480,7 @@ TEST_F(DataTypeTest, DataUtilsTest) {
// We expect that the above string will be matched in both cases
// where we have shape and where we don't
SparseTensorTypeProto<TensorProto_DataType_UINT64> sparse_proto;
DataType ten_dt = DataTypeUtils::ToType(sparse_proto);
DataType ten_dt = DataTypeUtils::ToType(sparse_proto.proto);
EXPECT_NE(ten_dt, nullptr);
EXPECT_EQ(tensor_uint64, *ten_dt);
DataType ten_from_str = DataTypeUtils::ToType(*ten_dt);
@ -485,8 +489,8 @@ TEST_F(DataTypeTest, DataUtilsTest) {
// Now add empty shape, we expect the same string
TensorShapeTypeProto<> shape_no_dims;
sparse_proto.SetShape(shape_no_dims);
ten_dt = DataTypeUtils::ToType(sparse_proto);
sparse_proto.SetShape(shape_no_dims.proto);
ten_dt = DataTypeUtils::ToType(sparse_proto.proto);
EXPECT_NE(ten_dt, nullptr);
EXPECT_EQ(tensor_uint64, *ten_dt);
ten_from_str = DataTypeUtils::ToType(*ten_dt);
@ -496,8 +500,8 @@ TEST_F(DataTypeTest, DataUtilsTest) {
// Now add shape with dimensions, we expect no difference
sparse_proto.ClearShape();
TensorShapeTypeProto<10, 12> shape_with_dim;
sparse_proto.SetShape(shape_with_dim);
ten_dt = DataTypeUtils::ToType(sparse_proto);
sparse_proto.SetShape(shape_with_dim.proto);
ten_dt = DataTypeUtils::ToType(sparse_proto.proto);
EXPECT_NE(ten_dt, nullptr);
EXPECT_EQ(tensor_uint64, *ten_dt);
ten_from_str = DataTypeUtils::ToType(*ten_dt);

View file

@ -120,12 +120,12 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType TypeToDataType<BFloat16>() {
}
template <typename T>
struct TTypeProto : ONNX_NAMESPACE::TypeProto {
struct TTypeProto {
TTypeProto(const std::vector<int64_t>* shape = nullptr) {
mutable_tensor_type()->set_elem_type(TypeToDataType<T>());
proto.mutable_tensor_type()->set_elem_type(TypeToDataType<T>());
if (shape) {
auto mutable_shape = mutable_tensor_type()->mutable_shape();
auto mutable_shape = proto.mutable_tensor_type()->mutable_shape();
for (auto i : *shape) {
auto* mutable_dim = mutable_shape->add_dim();
if (i != -1)
@ -135,6 +135,7 @@ struct TTypeProto : ONNX_NAMESPACE::TypeProto {
}
}
}
ONNX_NAMESPACE::TypeProto proto;
};
// Variable template for ONNX_NAMESPACE::TensorProto_DataTypes, s_type_proto<float>, etc..
@ -148,12 +149,13 @@ const TTypeProto<T> TTensorType<T>::s_type_proto;
// TypeProto for map<TKey, TVal>
template <typename TKey, typename TVal>
struct MTypeProto : ONNX_NAMESPACE::TypeProto {
struct MTypeProto {
MTypeProto() {
mutable_map_type()->set_key_type(TypeToDataType<TKey>());
mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType<TVal>());
mutable_map_type()->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim();
proto.mutable_map_type()->set_key_type(TypeToDataType<TKey>());
proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType<TVal>());
proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim();
}
ONNX_NAMESPACE::TypeProto proto;
};
template <typename TKey, typename TVal>
@ -166,13 +168,14 @@ const MTypeProto<TKey, TVal> MMapType<TKey, TVal>::s_map_type_proto;
// TypeProto for vector<map<TKey, TVal>>
template <typename TKey, typename TVal>
struct VectorOfMapTypeProto : ONNX_NAMESPACE::TypeProto {
struct VectorOfMapTypeProto {
VectorOfMapTypeProto() {
auto* map_type = mutable_sequence_type()->mutable_elem_type()->mutable_map_type();
auto* map_type = proto.mutable_sequence_type()->mutable_elem_type()->mutable_map_type();
map_type->set_key_type(TypeToDataType<TKey>());
map_type->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType<TVal>());
map_type->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim();
}
ONNX_NAMESPACE::TypeProto proto;
};
template <typename TKey, typename TVal>
@ -184,14 +187,15 @@ template <typename TKey, typename TVal>
const VectorOfMapTypeProto<TKey, TVal> VectorOfMapType<TKey, TVal>::s_vec_map_type_proto;
template <typename ElemType>
struct SequenceTensorTypeProto : ONNX_NAMESPACE::TypeProto {
struct SequenceTensorTypeProto {
SequenceTensorTypeProto() {
MLDataType dt = DataTypeImpl::GetTensorType<ElemType>();
const auto* elem_proto = dt->GetTypeProto();
mutable_sequence_type()->mutable_elem_type()->CopyFrom(*elem_proto);
auto* tensor_type = mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
proto.mutable_sequence_type()->mutable_elem_type()->CopyFrom(*elem_proto);
auto* tensor_type = proto.mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
tensor_type->set_elem_type(TypeToDataType<ElemType>());
}
ONNX_NAMESPACE::TypeProto proto;
};
template <typename ElemType>
@ -306,14 +310,14 @@ class OpTester {
OrtValue value;
value.Init(ptr.release(), DataTypeImpl::GetType<std::map<TKey, TVal>>(),
DataTypeImpl::GetType<std::map<TKey, TVal>>()->GetDeleteFunc());
input_data_.push_back(Data(NodeArg(name, &MMapType<TKey, TVal>::s_map_type_proto), std::move(value),
input_data_.push_back(Data(NodeArg(name, &MMapType<TKey, TVal>::s_map_type_proto.proto), std::move(value),
optional<float>(), optional<float>()));
}
template <typename T>
void AddMissingOptionalInput() {
std::string name; // empty == input doesn't exist
input_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto), OrtValue(), optional<float>(),
input_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto.proto), OrtValue(), optional<float>(),
optional<float>()));
}
@ -338,7 +342,7 @@ class OpTester {
template <typename T>
void AddMissingOptionalOutput() {
std::string name; // empty == input doesn't exist
output_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto), OrtValue(), optional<float>(),
output_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto.proto), OrtValue(), optional<float>(),
optional<float>()));
}
@ -374,7 +378,7 @@ class OpTester {
OrtValue ml_value;
ml_value.Init(ptr.release(), DataTypeImpl::GetType<std::vector<std::map<TKey, TVal>>>(),
DataTypeImpl::GetType<std::vector<std::map<TKey, TVal>>>()->GetDeleteFunc());
output_data_.push_back(Data(NodeArg(name, &VectorOfMapType<TKey, TVal>::s_vec_map_type_proto), std::move(ml_value),
output_data_.push_back(Data(NodeArg(name, &VectorOfMapType<TKey, TVal>::s_vec_map_type_proto.proto), std::move(ml_value),
optional<float>(), optional<float>()));
}
@ -530,7 +534,7 @@ class OpTester {
OrtValue value;
value.Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
auto node_arg = NodeArg(name, &type_proto);
auto node_arg = NodeArg(name, &type_proto.proto);
if (dim_params && !(dim_params->empty())) {
// If dim_params presents, configure node_arg's dim value based on dim_params, which supports symbolic dim and dim broadcast.
auto& dim_params_data = *dim_params;
@ -590,7 +594,7 @@ class OpTester {
ptr->SetElements(std::move(tensors));
value.Init(ptr.get(), mltype, mltype->GetDeleteFunc());
ptr.release();
data.push_back(Data(NodeArg(name, &SequenceTensorType<T>::s_sequence_tensor_type_proto), std::move(value),
data.push_back(Data(NodeArg(name, &SequenceTensorType<T>::s_sequence_tensor_type_proto.proto), std::move(value),
optional<float>(), optional<float>()));
}