diff --git a/onnxruntime/core/flatbuffers/ort.fbs b/onnxruntime/core/flatbuffers/ort.fbs index 3f6f64edf4..f18620f3f4 100644 --- a/onnxruntime/core/flatbuffers/ort.fbs +++ b/onnxruntime/core/flatbuffers/ort.fbs @@ -68,6 +68,15 @@ table TensorTypeAndShape{ shape:Shape; } +table MapType{ + key_type:TensorDataType; + value_type:onnxruntime.experimental.fbs.TypeInfo; +} + +table SequenceType{ + elem_type:onnxruntime.experimental.fbs.TypeInfo; +} + // Node enum NodeType : int32 { Primitive = 0, @@ -115,6 +124,8 @@ table ValueInfo { // TODO add support of Sequence/Map/SparseTensor/Opaque union TypeInfoValue { tensor_type:TensorTypeAndShape, + sequence_type:SequenceType, + map_type:MapType, } table TypeInfo { diff --git a/onnxruntime/core/flatbuffers/ort.fbs.h b/onnxruntime/core/flatbuffers/ort.fbs.h index b41038cd4d..d64d2d3c6f 100644 --- a/onnxruntime/core/flatbuffers/ort.fbs.h +++ b/onnxruntime/core/flatbuffers/ort.fbs.h @@ -22,6 +22,12 @@ struct DimensionValueBuilder; struct TensorTypeAndShape; struct TensorTypeAndShapeBuilder; +struct MapType; +struct MapTypeBuilder; + +struct SequenceType; +struct SequenceTypeBuilder; + struct EdgeEnd; struct NodeEdge; @@ -267,29 +273,35 @@ inline const char *EnumNameNodeType(NodeType e) { enum TypeInfoValue { TypeInfoValue_NONE = 0, TypeInfoValue_tensor_type = 1, + TypeInfoValue_sequence_type = 2, + TypeInfoValue_map_type = 3, TypeInfoValue_MIN = TypeInfoValue_NONE, - TypeInfoValue_MAX = TypeInfoValue_tensor_type + TypeInfoValue_MAX = TypeInfoValue_map_type }; -inline const TypeInfoValue (&EnumValuesTypeInfoValue())[2] { +inline const TypeInfoValue (&EnumValuesTypeInfoValue())[4] { static const TypeInfoValue values[] = { TypeInfoValue_NONE, - TypeInfoValue_tensor_type + TypeInfoValue_tensor_type, + TypeInfoValue_sequence_type, + TypeInfoValue_map_type }; return values; } inline const char * const *EnumNamesTypeInfoValue() { - static const char * const names[3] = { + static const char * const names[5] = { "NONE", "tensor_type", + "sequence_type", + "map_type", nullptr }; return names; } inline const char *EnumNameTypeInfoValue(TypeInfoValue e) { - if (flatbuffers::IsOutRange(e, TypeInfoValue_NONE, TypeInfoValue_tensor_type)) return ""; + if (flatbuffers::IsOutRange(e, TypeInfoValue_NONE, TypeInfoValue_map_type)) return ""; const size_t index = static_cast(e); return EnumNamesTypeInfoValue()[index]; } @@ -302,6 +314,14 @@ template<> struct TypeInfoValueTraits struct TypeInfoValueTraits { + static const TypeInfoValue enum_value = TypeInfoValue_sequence_type; +}; + +template<> struct TypeInfoValueTraits { + static const TypeInfoValue enum_value = TypeInfoValue_map_type; +}; + bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj, TypeInfoValue type); bool VerifyTypeInfoValueVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); @@ -577,6 +597,100 @@ inline flatbuffers::Offset CreateTensorTypeAndShape( return builder_.Finish(); } +struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MapTypeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY_TYPE = 4, + VT_VALUE_TYPE = 6 + }; + onnxruntime::experimental::fbs::TensorDataType key_type() const { + return static_cast(GetField(VT_KEY_TYPE, 0)); + } + const onnxruntime::experimental::fbs::TypeInfo *value_type() const { + return GetPointer(VT_VALUE_TYPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_KEY_TYPE) && + VerifyOffset(verifier, VT_VALUE_TYPE) && + verifier.VerifyTable(value_type()) && + verifier.EndTable(); + } +}; + +struct MapTypeBuilder { + typedef MapType Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_key_type(onnxruntime::experimental::fbs::TensorDataType key_type) { + fbb_.AddElement(MapType::VT_KEY_TYPE, static_cast(key_type), 0); + } + void add_value_type(flatbuffers::Offset value_type) { + fbb_.AddOffset(MapType::VT_VALUE_TYPE, value_type); + } + explicit MapTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMapType( + flatbuffers::FlatBufferBuilder &_fbb, + onnxruntime::experimental::fbs::TensorDataType key_type = onnxruntime::experimental::fbs::TensorDataType_UNDEFINED, + flatbuffers::Offset value_type = 0) { + MapTypeBuilder builder_(_fbb); + builder_.add_value_type(value_type); + builder_.add_key_type(key_type); + return builder_.Finish(); +} + +struct SequenceType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SequenceTypeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ELEM_TYPE = 4 + }; + const onnxruntime::experimental::fbs::TypeInfo *elem_type() const { + return GetPointer(VT_ELEM_TYPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ELEM_TYPE) && + verifier.VerifyTable(elem_type()) && + verifier.EndTable(); + } +}; + +struct SequenceTypeBuilder { + typedef SequenceType Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_elem_type(flatbuffers::Offset elem_type) { + fbb_.AddOffset(SequenceType::VT_ELEM_TYPE, elem_type); + } + explicit SequenceTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSequenceType( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset elem_type = 0) { + SequenceTypeBuilder builder_(_fbb); + builder_.add_elem_type(elem_type); + return builder_.Finish(); +} + struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef NodeEdgeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -969,6 +1083,12 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const onnxruntime::experimental::fbs::TensorTypeAndShape *value_as_tensor_type() const { return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_tensor_type ? static_cast(value()) : nullptr; } + const onnxruntime::experimental::fbs::SequenceType *value_as_sequence_type() const { + return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_sequence_type ? static_cast(value()) : nullptr; + } + const onnxruntime::experimental::fbs::MapType *value_as_map_type() const { + return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_map_type ? static_cast(value()) : nullptr; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DENOTATION) && @@ -984,6 +1104,14 @@ template<> inline const onnxruntime::experimental::fbs::TensorTypeAndShape *Type return value_as_tensor_type(); } +template<> inline const onnxruntime::experimental::fbs::SequenceType *TypeInfo::value_as() const { + return value_as_sequence_type(); +} + +template<> inline const onnxruntime::experimental::fbs::MapType *TypeInfo::value_as() const { + return value_as_map_type(); +} + struct TypeInfoBuilder { typedef TypeInfo Table; flatbuffers::FlatBufferBuilder &fbb_; @@ -2016,6 +2144,14 @@ inline bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case TypeInfoValue_sequence_type: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case TypeInfoValue_map_type: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 740d954859..d78d867e74 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -13,7 +13,11 @@ namespace experimental { namespace utils { #if !defined(ORT_MINIMAL_BUILD) -static flatbuffers::Offset GetTensorDimensionOrtFormat( +static Status SaveTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const TypeProto& type_proto, + flatbuffers::Offset& fbs_type_info) ORT_MUST_USE_RESULT; + +static flatbuffers::Offset SaveTensorDimensionOrtFormat( flatbuffers::FlatBufferBuilder& builder, const TensorShapeProto_Dimension& tensor_shape_dim) { auto denotation = builder.CreateString(tensor_shape_dim.denotation()); @@ -29,26 +33,45 @@ static flatbuffers::Offset GetTensorDimensionOrtFormat( return fbs::CreateDimension(builder, dim_val, denotation); } -static Status GetTensorShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder, - const TensorShapeProto& tensor_shape_proto, - flatbuffers::Offset& fbs_shape) { +static Status SaveTensorShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const TensorShapeProto& tensor_shape_proto, + flatbuffers::Offset& fbs_shape) { std::vector> dim; dim.reserve(tensor_shape_proto.dim_size()); for (const auto& d : tensor_shape_proto.dim()) { - auto fbs_d = GetTensorDimensionOrtFormat(builder, d); + auto fbs_d = SaveTensorDimensionOrtFormat(builder, d); dim.push_back(fbs_d); } fbs_shape = fbs::CreateShapeDirect(builder, &dim); return Status::OK(); } -static Status GetTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder, - const TypeProto_Tensor& tensor_type_proto, - flatbuffers::Offset& fbs_tensor_type) { +static Status SaveSequenceTypeOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const TypeProto_Sequence& sequence_type_proto, + flatbuffers::Offset& fbs_sequence_type) { + flatbuffers::Offset fbs_type_info; + ORT_RETURN_IF_ERROR(SaveTypeInfoOrtFormat(builder, sequence_type_proto.elem_type(), fbs_type_info)); + fbs_sequence_type = fbs::CreateSequenceType(builder, fbs_type_info); + return Status::OK(); +} + +static Status SaveMapTypeOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const TypeProto_Map& map_type_proto, + flatbuffers::Offset& fbs_map_type) { + flatbuffers::Offset fbs_type_info; + ORT_RETURN_IF_ERROR(SaveTypeInfoOrtFormat(builder, map_type_proto.value_type(), fbs_type_info)); + fbs_map_type = fbs::CreateMapType( + builder, static_cast(map_type_proto.key_type()), fbs_type_info); + return Status::OK(); +} + +static Status SaveTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const TypeProto_Tensor& tensor_type_proto, + flatbuffers::Offset& fbs_tensor_type) { // A flatbuffers::Offset of 0 means this shape is missing (was null when serializing) flatbuffers::Offset shape = 0; if (tensor_type_proto.has_shape()) { - ORT_RETURN_IF_ERROR(GetTensorShapeOrtFormat(builder, tensor_type_proto.shape(), shape)); + ORT_RETURN_IF_ERROR(SaveTensorShapeOrtFormat(builder, tensor_type_proto.shape(), shape)); } fbs_tensor_type = fbs::CreateTensorTypeAndShape( @@ -57,19 +80,37 @@ static Status GetTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& bui return Status::OK(); } -static Status GetTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder, - const TypeProto& type_proto, - flatbuffers::Offset& fbs_type_info) { +static Status SaveTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const TypeProto& type_proto, + flatbuffers::Offset& fbs_type_info) { auto denotation = builder.CreateString(type_proto.denotation()); auto value_type = fbs::TypeInfoValue_tensor_type; flatbuffers::Offset value; - if (type_proto.has_tensor_type()) { - flatbuffers::Offset tensor_type; - ORT_RETURN_IF_ERROR( - GetTensorTypeAndShapeOrtFormat(builder, type_proto.tensor_type(), tensor_type)); - value = tensor_type.Union(); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "We only support tensor type for now"); + auto value_case = type_proto.value_case(); + switch (value_case) { + case TypeProto::kTensorType: { + flatbuffers::Offset fbs_tensor_type; + ORT_RETURN_IF_ERROR( + SaveTensorTypeAndShapeOrtFormat(builder, type_proto.tensor_type(), fbs_tensor_type)); + value = fbs_tensor_type.Union(); + } break; + case TypeProto::kSequenceType: { + value_type = fbs::TypeInfoValue_sequence_type; + flatbuffers::Offset fbs_sequence_type; + ORT_RETURN_IF_ERROR( + SaveSequenceTypeOrtFormat(builder, type_proto.sequence_type(), fbs_sequence_type)); + value = fbs_sequence_type.Union(); + } break; + case TypeProto::kMapType: { + value_type = fbs::TypeInfoValue_map_type; + flatbuffers::Offset fbs_map_type; + ORT_RETURN_IF_ERROR( + SaveMapTypeOrtFormat(builder, type_proto.map_type(), fbs_map_type)); + value = fbs_map_type.Union(); + } break; + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "We do not support type [", value_case, "] for now"); + } break; } fbs::TypeInfoBuilder tb(builder); @@ -88,7 +129,7 @@ Status SaveValueInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset type_info = 0; // 0 indicates null if (value_info_proto.has_type()) { ORT_RETURN_IF_ERROR( - GetTypeInfoOrtFormat(builder, value_info_proto.type(), type_info)); + SaveTypeInfoOrtFormat(builder, value_info_proto.type(), type_info)); } else { // we have a NodeArg for missing optional values (empty name, no type) so allow for that. // everything else should have type info @@ -126,7 +167,7 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, string_data = builder.CreateVectorOfStrings(string_data_vec); } else { std::unique_ptr unpacked_tensor; - size_t tensor_byte_size; + size_t tensor_byte_size = 0; ORT_RETURN_IF_ERROR( onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor, tensor_byte_size)); raw_data = builder.CreateVector(unpacked_tensor.get(), tensor_byte_size); @@ -228,6 +269,9 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder, #undef GET_FBS_ATTR #undef GET_DATA_VEC +static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info, + TypeProto& type_proto) ORT_MUST_USE_RESULT; + Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& initializer) { initializer.Clear(); @@ -301,6 +345,23 @@ static Status LoadTensorTypeAndShapeOrtFormat(const fbs::TensorTypeAndShape& fbs return Status::OK(); } +static Status LoadSequenceTypeOrtFormat(const fbs::SequenceType& fbs_sequence_type, + TypeProto_Sequence& sequence_type_proto) { + auto fbs_type_info = fbs_sequence_type.elem_type(); + ORT_RETURN_IF(nullptr == fbs_type_info, "Null value type info in fbs::SequenceType. Invalid ORT format model."); + ORT_RETURN_IF_ERROR(LoadTypeInfoOrtFormat(*fbs_type_info, *sequence_type_proto.mutable_elem_type())); + return Status::OK(); +} + +static Status LoadMapTypeOrtFormat(const fbs::MapType& fbs_map_type, + TypeProto_Map& map_type_proto) { + map_type_proto.set_key_type(fbs_map_type.key_type()); + auto fbs_type_info = fbs_map_type.value_type(); + ORT_RETURN_IF(nullptr == fbs_type_info, "Null value type info in fbs::MapType. Invalid ORT format model."); + ORT_RETURN_IF_ERROR(LoadTypeInfoOrtFormat(*fbs_type_info, *map_type_proto.mutable_value_type())); + return Status::OK(); +} + static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info, TypeProto& type_proto) { LoadStringFromOrtFormat(*type_proto.mutable_denotation(), fbs_type_info.denotation()); @@ -309,8 +370,16 @@ static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info, auto fbs_tensor_type = fbs_type_info.value_as_tensor_type(); ORT_RETURN_IF(nullptr == fbs_tensor_type, "Null tensor type info. Invalid ORT format model."); ORT_RETURN_IF_ERROR(LoadTensorTypeAndShapeOrtFormat(*fbs_tensor_type, *type_proto.mutable_tensor_type())); + } else if (value_type == fbs::TypeInfoValue_sequence_type) { + auto fbs_sequence_type = fbs_type_info.value_as_sequence_type(); + ORT_RETURN_IF(nullptr == fbs_sequence_type, "Null sequence type info. Invalid ORT format model."); + ORT_RETURN_IF_ERROR(LoadSequenceTypeOrtFormat(*fbs_sequence_type, *type_proto.mutable_sequence_type())); + } else if (value_type == fbs::TypeInfoValue_map_type) { + auto fbs_map_type = fbs_type_info.value_as_map_type(); + ORT_RETURN_IF(nullptr == fbs_map_type, "Null map type info. Invalid ORT format model."); + ORT_RETURN_IF_ERROR(LoadMapTypeOrtFormat(*fbs_map_type, *type_proto.mutable_map_type())); } else { - // TODO: This may be required for traditional ML models. + // We do not support SparseTensor and Opaque for now return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Type:", value_type, " is not supported currently"); } diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index 4c53808bba..437d0c2c07 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/data_types.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" @@ -23,17 +24,41 @@ class InferenceSessionGetGraphWrapper : public InferenceSession { const Environment& env) : InferenceSession(session_options, env) { } - const Graph& GetGraph() { + const Graph& GetGraph() const { return model_->MainGraph(); } - const SessionState& GetSessionState() { + const SessionState& GetSessionState() const { return InferenceSession::GetSessionState(); } }; namespace test { +struct OrtModelTestInfo { + std::basic_string model_filename; + std::string logid; + NameMLValMap inputs; + std::vector output_names; + std::function&)> output_verifier; + std::vector> configs; +}; + +void RunOrtModel(const OrtModelTestInfo& test_info) { + SessionOptions so; + so.session_logid = test_info.logid; + for (const auto& config : test_info.configs) + so.AddConfigEntry(config.first.c_str(), config.second.c_str()); + + InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename + ASSERT_STATUS_OK(session_object.Initialize()); + + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run(test_info.inputs, test_info.output_names, &fetches)); + test_info.output_verifier(fetches); +} + #if !defined(ORT_MINIMAL_BUILD) // Same Tensor from ONNX and ORT format will have different binary representation, need to compare value by value static void CompareTensors(const OrtValue& left_value, const OrtValue& right_value) { @@ -56,89 +81,60 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val } } +static void CompareTypeProtos(const TypeProto& left_type_proto, const TypeProto& right_type_proto) { + ASSERT_EQ(left_type_proto.denotation(), right_type_proto.denotation()); + + ASSERT_EQ(left_type_proto.has_tensor_type(), right_type_proto.has_tensor_type()); + ASSERT_EQ(left_type_proto.has_sequence_type(), right_type_proto.has_sequence_type()); + ASSERT_EQ(left_type_proto.has_map_type(), right_type_proto.has_map_type()); + + if (left_type_proto.has_tensor_type()) { + const auto& left_tensor_type = left_type_proto.tensor_type(); + const auto& right_tensor_type = right_type_proto.tensor_type(); + + ASSERT_EQ(left_tensor_type.elem_type(), right_tensor_type.elem_type()); + + const auto& left_shape = left_tensor_type.shape(); + const auto& right_shape = right_tensor_type.shape(); + + ASSERT_EQ(left_shape.dim_size(), right_shape.dim_size()); + for (int i = 0; i < left_shape.dim_size(); i++) { + const auto& left_dim = left_shape.dim(i); + const auto& right_dim = right_shape.dim(i); + ASSERT_EQ(left_dim.has_dim_value(), right_dim.has_dim_value()); + ASSERT_EQ(left_dim.dim_value(), right_dim.dim_value()); + ASSERT_EQ(left_dim.has_dim_param(), right_dim.has_dim_param()); + ASSERT_EQ(left_dim.dim_param(), right_dim.dim_param()); + } + } else if (left_type_proto.has_sequence_type()) { + CompareTypeProtos(left_type_proto.sequence_type().elem_type(), right_type_proto.sequence_type().elem_type()); + } else if (left_type_proto.has_map_type()) { + const auto& left_map = left_type_proto.map_type(); + const auto& right_map = right_type_proto.map_type(); + ASSERT_EQ(left_map.key_type(), right_map.key_type()); + CompareTypeProtos(left_map.value_type(), right_map.value_type()); + } else { + FAIL(); // We do not support SparseTensor and Opaque for now + } +} + static void CompareValueInfos(const ValueInfoProto& left, const ValueInfoProto& right) { ASSERT_EQ(left.name(), right.name()); ASSERT_EQ(left.doc_string(), right.doc_string()); - std::string left_data; - std::string right_data; - - const auto& left_type_proto = left.type(); - const auto& right_type_proto = right.type(); - - ASSERT_EQ(left_type_proto.denotation(), right_type_proto.denotation()); - ASSERT_TRUE(left_type_proto.has_tensor_type()); - ASSERT_TRUE(right_type_proto.has_tensor_type()); - - const auto& left_tensor_type = left_type_proto.tensor_type(); - const auto& right_tensor_type = right_type_proto.tensor_type(); - - ASSERT_EQ(left_tensor_type.elem_type(), right_tensor_type.elem_type()); - - const auto& left_shape = left_tensor_type.shape(); - const auto& right_shape = right_tensor_type.shape(); - - ASSERT_EQ(left_shape.dim_size(), right_shape.dim_size()); - for (int i = 0; i < left_shape.dim_size(); i++) { - const auto& left_dim = left_shape.dim(i); - const auto& right_dim = right_shape.dim(i); - ASSERT_EQ(left_dim.has_dim_value(), right_dim.has_dim_value()); - ASSERT_EQ(left_dim.dim_value(), right_dim.dim_value()); - ASSERT_EQ(left_dim.has_dim_param(), right_dim.has_dim_param()); - ASSERT_EQ(left_dim.dim_param(), right_dim.dim_param()); - } + CompareTypeProtos(left.type(), right.type()); } -TEST(OrtModelOnlyTests, SerializeToOrtFormat) { - const auto output_file = ORT_TSTR("ort_github_issue_4031.onnx.ort"); - SessionOptions so; - so.session_logid = "SerializeToOrtFormat"; - so.optimized_model_filepath = output_file; - // not strictly necessary - type should be inferred from the filename - so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT, "ORT"); +void CompareGraphAndSessionState(const InferenceSessionGetGraphWrapper& session_object_1, + const InferenceSessionGetGraphWrapper& session_object_2) { + const auto& graph_1 = session_object_1.GetGraph(); + const auto& graph_2 = session_object_2.GetGraph(); - InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; + const auto& session_state_1 = session_object_1.GetSessionState(); + const auto& session_state_2 = session_object_2.GetSessionState(); - // create .ort file during Initialize due to values in SessionOptions - ASSERT_STATUS_OK(session_object.Load(ORT_TSTR("testdata/ort_github_issue_4031.onnx"))); - ASSERT_STATUS_OK(session_object.Initialize()); - - // create inputs - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f}, - &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("state_var_in", ml_value)); - - // prepare outputs - std::vector output_names{"state_var_out"}; - std::vector fetches; - - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); - - SessionOptions so2; - so.session_logid = "LoadOrtFormat"; - // not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we - // think we're testing set it. - so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT"); - - // load serialized version - InferenceSessionGetGraphWrapper session_object2{so2, GetEnvironment()}; - ASSERT_STATUS_OK(session_object2.Load(output_file)); - ASSERT_STATUS_OK(session_object2.Initialize()); - - // compare contents on Graph instances - const auto& graph = session_object.GetGraph(); - const auto& graph2 = session_object2.GetGraph(); - - const auto& session_state = session_object.GetSessionState(); - const auto& session_state2 = session_object2.GetSessionState(); - - const auto& name_idx_map = session_state.GetOrtValueNameIdxMap(); - const auto& name_idx_map2 = session_state2.GetOrtValueNameIdxMap(); - - const auto& i1 = session_state.GetInitializedTensors(); - const auto& i2 = session_state2.GetInitializedTensors(); + const auto& i1 = session_state_1.GetInitializedTensors(); + const auto& i2 = session_state_2.GetInitializedTensors(); ASSERT_EQ(i1.size(), i2.size()); for (const auto& pair : i1) { @@ -148,24 +144,12 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) { const OrtValue& left = pair.second; const OrtValue& right = iter->second; CompareTensors(left, right); - - // check NodeArgs for both initializers. need to get name from map as we store the initialized tensors against - // their OrtValueIdx in SessionState. - std::string name; - std::string name2; - ASSERT_STATUS_OK(name_idx_map.GetName(pair.first, name)); - ASSERT_STATUS_OK(name_idx_map2.GetName(pair.first, name2)); - ASSERT_EQ(name, name2); - - const auto& left_nodearg = *graph.GetNodeArg(name); - const auto& right_nodearg = *graph2.GetNodeArg(name2); - CompareValueInfos(left_nodearg.ToProto(), right_nodearg.ToProto()); } // check all node args are fine - for (const auto& input : graph.GetInputs()) { - const auto& left = *graph.GetNodeArg(input->Name()); - const auto* right = graph2.GetNodeArg(input->Name()); + for (const auto& input : graph_1.GetInputsIncludingInitializers()) { + const auto& left = *graph_1.GetNodeArg(input->Name()); + const auto* right = graph_2.GetNodeArg(input->Name()); ASSERT_TRUE(right != nullptr); const auto& left_proto = left.ToProto(); @@ -173,8 +157,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) { CompareValueInfos(left_proto, right_proto); } - for (const auto& left : graph.Nodes()) { - const auto* right = graph2.GetNode(left.Index()); + for (const auto& left : graph_1.Nodes()) { + const auto* right = graph_2.GetNode(left.Index()); ASSERT_TRUE(right != nullptr); const auto& left_outputs = left.OutputDefs(); const auto& right_outputs = right->OutputDefs(); @@ -192,47 +176,165 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) { } } } - - // check results match - std::vector fetches2; - ASSERT_STATUS_OK(session_object2.Run(feeds, output_names, &fetches2)); - - const auto& output = fetches[0].Get(); - ASSERT_TRUE(output.Shape().Size() == 1); - ASSERT_TRUE(output.Data()[0] == 125.f); - - const auto& output2 = fetches2[0].Get(); - ASSERT_TRUE(output2.Shape().Size() == 1); - ASSERT_TRUE(output2.Data()[0] == 125.f); } -#endif -// test that we can deserialize and run a previously saved ORT format model -TEST(OrtModelOnlyTests, LoadOrtFormatModel) { - const auto model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort"); +void SaveAndCompareModels(const std::string& onnx_file, const std::basic_string& ort_file) { SessionOptions so; - so.session_logid = "LoadOrtFormatModel"; - + so.session_logid = "SerializeToOrtFormat"; + so.optimized_model_filepath = ort_file; + // not strictly necessary - type should be inferred from the filename + so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT, "ORT"); InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_filename)); // infer type from filename + + // create .ort file during Initialize due to values in SessionOptions + ASSERT_STATUS_OK(session_object.Load(onnx_file)); ASSERT_STATUS_OK(session_object.Initialize()); + SessionOptions so2; + so2.session_logid = "LoadOrtFormat"; + // not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we + // think we're testing set it. + so2.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT"); + + // load serialized version + InferenceSessionGetGraphWrapper session_object2{so2, GetEnvironment()}; + ASSERT_STATUS_OK(session_object2.Load(ort_file)); + ASSERT_STATUS_OK(session_object2.Initialize()); + + CompareGraphAndSessionState(session_object, session_object2); +} + +TEST(OrtModelOnlyTests, SerializeToOrtFormat) { + const std::basic_string ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort"); + SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file); + + OrtModelTestInfo test_info; + test_info.model_filename = ort_file; + test_info.logid = "SerializeToOrtFormat"; + test_info.configs.push_back(std::make_pair(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT")); + OrtValue ml_value; CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f}, &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("state_var_in", ml_value)); + test_info.inputs.insert(std::make_pair("state_var_in", ml_value)); // prepare outputs - std::vector output_names{"state_var_out"}; - std::vector fetches; + test_info.output_names = {"state_var_out"}; + test_info.output_verifier = [](const std::vector& fetches) { + const auto& output = fetches[0].Get(); + ASSERT_TRUE(output.Shape().Size() == 1); + ASSERT_TRUE(output.Data()[0] == 125.f); + }; - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); - - const auto& output = fetches[0].Get(); - ASSERT_TRUE(output.Shape().Size() == 1); - ASSERT_TRUE(output.Data()[0] == 125.f); + RunOrtModel(test_info); } +#if !defined(DISABLE_ML_OPS) +TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { + const std::basic_string ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort"); + SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file); + + OrtModelTestInfo test_info; + test_info.model_filename = ort_file; + test_info.logid = "SerializeToOrtFormatMLOps"; + test_info.configs.push_back(std::make_pair(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT")); + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {3, 2}, + {0.f, 1.f, 1.f, 1.f, 2.f, 0.f}, &ml_value); + test_info.inputs.insert(std::make_pair("input", ml_value)); + + // prepare outputs + test_info.output_names = {"output_label", "output_probability"}; + test_info.output_verifier = [](const std::vector& fetches) { + const auto& output_0 = fetches[0].Get(); + int64_t tensor_size = 3; + ASSERT_EQ(tensor_size, output_0.Shape().Size()); + const auto& output_0_data = output_0.Data(); + for (int64_t i = 0; i < tensor_size; i++) + ASSERT_TRUE(output_0_data[i] == "A"); + + VectorMapStringToFloat expected_output_1 = {{{"A", 0.572734f}, {"B", 0.427266f}}, + {{"A", 0.596016f}, {"B", 0.403984f}}, + {{"A", 0.656315f}, {"B", 0.343685f}}}; + const auto& actual_output_1 = fetches[1].Get(); + ASSERT_EQ(actual_output_1.size(), 3); + for (size_t i = 0; i < 3; i++) { + const auto& expected = expected_output_1[i]; + const auto& actual = actual_output_1[i]; + ASSERT_EQ(actual.size(), 2); + ASSERT_NEAR(expected.at("A"), actual.at("A"), 1e-6); + ASSERT_NEAR(expected.at("B"), actual.at("B"), 1e-6); + } + }; + + RunOrtModel(test_info); +} +#endif // #if !defined(DISABLE_ML_OPS) +#endif // #if !defined(ORT_MINIMAL_BUILD) + +// test that we can deserialize and run a previously saved ORT format model +TEST(OrtModelOnlyTests, LoadOrtFormatModel) { + OrtModelTestInfo test_info; + test_info.model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort"); + test_info.logid = "LoadOrtFormatModel"; + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f}, + &ml_value); + test_info.inputs.insert(std::make_pair("state_var_in", ml_value)); + + // prepare outputs + test_info.output_names = {"state_var_out"}; + test_info.output_verifier = [](const std::vector& fetches) { + const auto& output = fetches[0].Get(); + ASSERT_TRUE(output.Shape().Size() == 1); + ASSERT_TRUE(output.Data()[0] == 125.f); + }; + + RunOrtModel(test_info); +} + +#if !defined(DISABLE_ML_OPS) +// test that we can deserialize and run a previously saved ORT format model +// for a model with sequence and map outputs +TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) { + OrtModelTestInfo test_info; + test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.ort"); + test_info.logid = "LoadOrtFormatModelMLOps"; + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {3, 2}, + {0.f, 1.f, 1.f, 1.f, 2.f, 0.f}, &ml_value); + test_info.inputs.insert(std::make_pair("input", ml_value)); + + // prepare outputs + test_info.output_names = {"output_label", "output_probability"}; + test_info.output_verifier = [](const std::vector& fetches) { + const auto& output_0 = fetches[0].Get(); + int64_t tensor_size = 3; + ASSERT_EQ(tensor_size, output_0.Shape().Size()); + const auto& output_0_data = output_0.Data(); + for (int64_t i = 0; i < tensor_size; i++) + ASSERT_TRUE(output_0_data[i] == "A"); + + VectorMapStringToFloat expected_output_1 = {{{"A", 0.572734f}, {"B", 0.427266f}}, + {{"A", 0.596016f}, {"B", 0.403984f}}, + {{"A", 0.656315f}, {"B", 0.343685f}}}; + const auto& actual_output_1 = fetches[1].Get(); + ASSERT_EQ(actual_output_1.size(), 3); + for (size_t i = 0; i < 3; i++) { + const auto& expected = expected_output_1[i]; + const auto& actual = actual_output_1[i]; + ASSERT_EQ(actual.size(), 2); + ASSERT_NEAR(expected.at("A"), actual.at("A"), 1e-6); + ASSERT_NEAR(expected.at("B"), actual.at("B"), 1e-6); + } + }; + + RunOrtModel(test_info); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx new file mode 100644 index 0000000000..3754f87bc2 Binary files /dev/null and b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx differ diff --git a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort new file mode 100644 index 0000000000..7eae709f43 Binary files /dev/null and b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort differ diff --git a/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.readme.txt b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.readme.txt new file mode 100644 index 0000000000..e0a59dd137 --- /dev/null +++ b/onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.readme.txt @@ -0,0 +1 @@ +The model sklearn_bin_voting_classifier_soft.onnx is from the scikit-learn onnx converter tests