From a37e2e33b4ce93a8a4124be932298dc338d97071 Mon Sep 17 00:00:00 2001 From: Chih-Hsuan Yen Date: Fri, 26 Jun 2020 11:34:08 +0800 Subject: [PATCH] Add compatibility with Protobuf 3.12 (#4291) In Protobuf 3.12, classes generated from protobuf files are declared as `final`, so use those classes as members rather than base classes. Ref: https://github.com/protocolbuffers/protobuf/releases/tag/v3.12.0 --- onnxruntime/test/framework/data_types_test.cc | 92 ++++++++++--------- .../test/providers/provider_test_utils.h | 40 ++++---- 2 files changed, 70 insertions(+), 62 deletions(-) diff --git a/onnxruntime/test/framework/data_types_test.cc b/onnxruntime/test/framework/data_types_test.cc index 3e0fddb857..74ba617469 100644 --- a/onnxruntime/test/framework/data_types_test.cc +++ b/onnxruntime/test/framework/data_types_test.cc @@ -149,44 +149,48 @@ struct DimSetter { }; template -struct TensorShapeTypeProto : public TensorShapeProto { +struct TensorShapeTypeProto { TensorShapeTypeProto() { - DimSetter::set(*this); + DimSetter::set(proto); } + TensorShapeProto proto; }; template <> -struct TensorShapeTypeProto<> : public TensorShapeProto {}; +struct TensorShapeTypeProto<> { TensorShapeProto proto; }; template -struct TensorTypeProto : public TypeProto { +struct TensorTypeProto { TensorTypeProto() { - mutable_tensor_type()->set_elem_type(T); + proto.mutable_tensor_type()->set_elem_type(T); } + TypeProto proto; }; template -struct SparseTensorTypeProto : public TypeProto { +struct SparseTensorTypeProto { SparseTensorTypeProto() { - mutable_sparse_tensor_type()->set_elem_type(T); + proto.mutable_sparse_tensor_type()->set_elem_type(T); } void SetShape(const TensorShapeProto& shape) { - mutable_sparse_tensor_type()->mutable_shape()->CopyFrom(shape); + proto.mutable_sparse_tensor_type()->mutable_shape()->CopyFrom(shape); } void SetShape(TensorShapeProto&& shape) { - *mutable_sparse_tensor_type()->mutable_shape() = std::move(shape); + *proto.mutable_sparse_tensor_type()->mutable_shape() = std::move(shape); } void ClearShape() { - mutable_sparse_tensor_type()->clear_shape(); + proto.mutable_sparse_tensor_type()->clear_shape(); } + TypeProto proto; }; template -struct MapTypeProto : public TypeProto { +struct MapTypeProto { MapTypeProto() { - mutable_map_type()->set_key_type(key); - mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(value); + proto.mutable_map_type()->set_key_type(key); + proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(value); } + TypeProto proto; }; class DataTypeTest : public testing::Test { @@ -241,9 +245,9 @@ TEST_F(DataTypeTest, OpaqueRegistrationTest) { TEST_F(DataTypeTest, MapStringStringTest) { TensorTypeProto tensor_type; auto ml_str_str = DataTypeImpl::GetType(); - EXPECT_TRUE(DataTypeImpl::GetTensorType()->IsCompatible(tensor_type)); - EXPECT_FALSE(DataTypeImpl::GetTensorType()->IsCompatible(tensor_type)); - EXPECT_FALSE(ml_str_str->IsCompatible(tensor_type)); + EXPECT_TRUE(DataTypeImpl::GetTensorType()->IsCompatible(tensor_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetTensorType()->IsCompatible(tensor_type.proto)); + EXPECT_FALSE(ml_str_str->IsCompatible(tensor_type.proto)); utils::ContainerChecker c_checker(ml_str_str); bool result = c_checker.IsMapOf(); EXPECT_TRUE(result); @@ -257,17 +261,17 @@ TEST_F(DataTypeTest, MapStringStringTest) { MapTypeProto maps2s_type; MapTypeProto maps2i_type; - EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type)); - EXPECT_FALSE(ml_str_str->IsCompatible(maps2i_type)); + EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type.proto)); + EXPECT_FALSE(ml_str_str->IsCompatible(maps2i_type.proto)); } TEST_F(DataTypeTest, MapStringInt64Test) { MapTypeProto maps2s_type; MapTypeProto maps2i_type; TensorTypeProto tensor_type; - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(maps2s_type)); - EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(maps2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(maps2s_type.proto)); + EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(maps2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); utils::ContainerChecker c_checker(DataTypeImpl::GetType()); bool result = c_checker.IsMapOf(); @@ -278,36 +282,36 @@ TEST_F(DataTypeTest, MapStringFloatTest) { MapTypeProto maps2f_type; MapTypeProto maps2i_type; TensorTypeProto tensor_type; - EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(maps2f_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(maps2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(maps2f_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(maps2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); } TEST_F(DataTypeTest, MapStringDoubleTest) { MapTypeProto maps2d_type; MapTypeProto maps2i_type; TensorTypeProto tensor_type; - EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(maps2d_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(maps2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(maps2d_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(maps2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); } TEST_F(DataTypeTest, MapInt64StringTest) { MapTypeProto mapi2s_type; MapTypeProto mapi2i_type; TensorTypeProto tensor_type; - EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(mapi2s_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(mapi2s_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); } TEST_F(DataTypeTest, MapInt64DoubleTest) { MapTypeProto mapi2d_type; MapTypeProto mapi2i_type; TensorTypeProto tensor_type; - EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(mapi2d_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(mapi2d_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); } TEST_F(DataTypeTest, RecursiveMapTest) { @@ -366,9 +370,9 @@ TEST_F(DataTypeTest, VectorMapStringToFloatTest) { TensorTypeProto tensor_type; EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(vector_map_string_to_float)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2d_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2d_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); utils::ContainerChecker c_check(DataTypeImpl::GetType()); bool result = c_check.IsSequenceOf(); EXPECT_TRUE(result); @@ -384,9 +388,9 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) { TensorTypeProto tensor_type; EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(type_proto)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2d_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type)); - EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2d_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type.proto)); + EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); } TEST_F(DataTypeTest, BFloat16Test) { @@ -476,7 +480,7 @@ TEST_F(DataTypeTest, DataUtilsTest) { // We expect that the above string will be matched in both cases // where we have shape and where we don't SparseTensorTypeProto sparse_proto; - DataType ten_dt = DataTypeUtils::ToType(sparse_proto); + DataType ten_dt = DataTypeUtils::ToType(sparse_proto.proto); EXPECT_NE(ten_dt, nullptr); EXPECT_EQ(tensor_uint64, *ten_dt); DataType ten_from_str = DataTypeUtils::ToType(*ten_dt); @@ -485,8 +489,8 @@ TEST_F(DataTypeTest, DataUtilsTest) { // Now add empty shape, we expect the same string TensorShapeTypeProto<> shape_no_dims; - sparse_proto.SetShape(shape_no_dims); - ten_dt = DataTypeUtils::ToType(sparse_proto); + sparse_proto.SetShape(shape_no_dims.proto); + ten_dt = DataTypeUtils::ToType(sparse_proto.proto); EXPECT_NE(ten_dt, nullptr); EXPECT_EQ(tensor_uint64, *ten_dt); ten_from_str = DataTypeUtils::ToType(*ten_dt); @@ -496,8 +500,8 @@ TEST_F(DataTypeTest, DataUtilsTest) { // Now add shape with dimensions, we expect no difference sparse_proto.ClearShape(); TensorShapeTypeProto<10, 12> shape_with_dim; - sparse_proto.SetShape(shape_with_dim); - ten_dt = DataTypeUtils::ToType(sparse_proto); + sparse_proto.SetShape(shape_with_dim.proto); + ten_dt = DataTypeUtils::ToType(sparse_proto.proto); EXPECT_NE(ten_dt, nullptr); EXPECT_EQ(tensor_uint64, *ten_dt); ten_from_str = DataTypeUtils::ToType(*ten_dt); diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h index 50befb3191..162556586f 100644 --- a/onnxruntime/test/providers/provider_test_utils.h +++ b/onnxruntime/test/providers/provider_test_utils.h @@ -120,12 +120,12 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType TypeToDataType() { } template -struct TTypeProto : ONNX_NAMESPACE::TypeProto { +struct TTypeProto { TTypeProto(const std::vector* shape = nullptr) { - mutable_tensor_type()->set_elem_type(TypeToDataType()); + proto.mutable_tensor_type()->set_elem_type(TypeToDataType()); if (shape) { - auto mutable_shape = mutable_tensor_type()->mutable_shape(); + auto mutable_shape = proto.mutable_tensor_type()->mutable_shape(); for (auto i : *shape) { auto* mutable_dim = mutable_shape->add_dim(); if (i != -1) @@ -135,6 +135,7 @@ struct TTypeProto : ONNX_NAMESPACE::TypeProto { } } } + ONNX_NAMESPACE::TypeProto proto; }; // Variable template for ONNX_NAMESPACE::TensorProto_DataTypes, s_type_proto, etc.. @@ -148,12 +149,13 @@ const TTypeProto TTensorType::s_type_proto; // TypeProto for map template -struct MTypeProto : ONNX_NAMESPACE::TypeProto { +struct MTypeProto { MTypeProto() { - mutable_map_type()->set_key_type(TypeToDataType()); - mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType()); - mutable_map_type()->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim(); + proto.mutable_map_type()->set_key_type(TypeToDataType()); + proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType()); + proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim(); } + ONNX_NAMESPACE::TypeProto proto; }; template @@ -166,13 +168,14 @@ const MTypeProto MMapType::s_map_type_proto; // TypeProto for vector> template -struct VectorOfMapTypeProto : ONNX_NAMESPACE::TypeProto { +struct VectorOfMapTypeProto { VectorOfMapTypeProto() { - auto* map_type = mutable_sequence_type()->mutable_elem_type()->mutable_map_type(); + auto* map_type = proto.mutable_sequence_type()->mutable_elem_type()->mutable_map_type(); map_type->set_key_type(TypeToDataType()); map_type->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType()); map_type->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim(); } + ONNX_NAMESPACE::TypeProto proto; }; template @@ -184,14 +187,15 @@ template const VectorOfMapTypeProto VectorOfMapType::s_vec_map_type_proto; template -struct SequenceTensorTypeProto : ONNX_NAMESPACE::TypeProto { +struct SequenceTensorTypeProto { SequenceTensorTypeProto() { MLDataType dt = DataTypeImpl::GetTensorType(); const auto* elem_proto = dt->GetTypeProto(); - mutable_sequence_type()->mutable_elem_type()->CopyFrom(*elem_proto); - auto* tensor_type = mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type(); + proto.mutable_sequence_type()->mutable_elem_type()->CopyFrom(*elem_proto); + auto* tensor_type = proto.mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type(); tensor_type->set_elem_type(TypeToDataType()); } + ONNX_NAMESPACE::TypeProto proto; }; template @@ -306,14 +310,14 @@ class OpTester { OrtValue value; value.Init(ptr.release(), DataTypeImpl::GetType>(), DataTypeImpl::GetType>()->GetDeleteFunc()); - input_data_.push_back(Data(NodeArg(name, &MMapType::s_map_type_proto), std::move(value), + input_data_.push_back(Data(NodeArg(name, &MMapType::s_map_type_proto.proto), std::move(value), optional(), optional())); } template void AddMissingOptionalInput() { std::string name; // empty == input doesn't exist - input_data_.push_back(Data(NodeArg(name, &TTensorType::s_type_proto), OrtValue(), optional(), + input_data_.push_back(Data(NodeArg(name, &TTensorType::s_type_proto.proto), OrtValue(), optional(), optional())); } @@ -338,7 +342,7 @@ class OpTester { template void AddMissingOptionalOutput() { std::string name; // empty == input doesn't exist - output_data_.push_back(Data(NodeArg(name, &TTensorType::s_type_proto), OrtValue(), optional(), + output_data_.push_back(Data(NodeArg(name, &TTensorType::s_type_proto.proto), OrtValue(), optional(), optional())); } @@ -374,7 +378,7 @@ class OpTester { OrtValue ml_value; ml_value.Init(ptr.release(), DataTypeImpl::GetType>>(), DataTypeImpl::GetType>>()->GetDeleteFunc()); - output_data_.push_back(Data(NodeArg(name, &VectorOfMapType::s_vec_map_type_proto), std::move(ml_value), + output_data_.push_back(Data(NodeArg(name, &VectorOfMapType::s_vec_map_type_proto.proto), std::move(ml_value), optional(), optional())); } @@ -530,7 +534,7 @@ class OpTester { OrtValue value; value.Init(p_tensor.release(), DataTypeImpl::GetType(), DataTypeImpl::GetType()->GetDeleteFunc()); - auto node_arg = NodeArg(name, &type_proto); + auto node_arg = NodeArg(name, &type_proto.proto); if (dim_params && !(dim_params->empty())) { // If dim_params presents, configure node_arg's dim value based on dim_params, which supports symbolic dim and dim broadcast. auto& dim_params_data = *dim_params; @@ -590,7 +594,7 @@ class OpTester { ptr->SetElements(std::move(tensors)); value.Init(ptr.get(), mltype, mltype->GetDeleteFunc()); ptr.release(); - data.push_back(Data(NodeArg(name, &SequenceTensorType::s_sequence_tensor_type_proto), std::move(value), + data.push_back(Data(NodeArg(name, &SequenceTensorType::s_sequence_tensor_type_proto.proto), std::move(value), optional(), optional())); }