Add sequence and map support in ORT mobile file format, add UT (#5066)

* Init change

* Update schema header

* Address review comments

* fix for DISABLE_ML_OPS build break

* Fix build break

Co-authored-by: gwang0000 <62914304+gwang0000@users.noreply.github.com>
This commit is contained in:
gwang-msft 2020-09-04 21:33:48 -07:00 committed by GitHub
parent de58720a97
commit d922cb1081
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 468 additions and 149 deletions

View file

@ -68,6 +68,15 @@ table TensorTypeAndShape{
shape:Shape;
}
// Type descriptor for a map value: a key data type plus a nested TypeInfo
// describing the mapped values.
table MapType{
// data type of the map keys, expressed as a TensorDataType
key_type:TensorDataType;
// type of the map values, described by a nested TypeInfo
value_type:onnxruntime.experimental.fbs.TypeInfo;
}
// Type descriptor for a sequence value; holds a single nested TypeInfo
// describing the element type.
table SequenceType{
// type of the sequence elements, described by a nested TypeInfo
elem_type:onnxruntime.experimental.fbs.TypeInfo;
}
// Node
enum NodeType : int32 {
Primitive = 0,
@ -115,6 +124,8 @@ table ValueInfo {
// TODO add support of SparseTensor/Opaque
// The kinds of type description a TypeInfo can carry.
// NOTE: the generated discriminator values follow declaration order
// (tensor_type = 1, sequence_type = 2, map_type = 3), so new members
// should only be appended.
union TypeInfoValue {
tensor_type:TensorTypeAndShape,
sequence_type:SequenceType,
map_type:MapType,
}
table TypeInfo {

View file

@ -22,6 +22,12 @@ struct DimensionValueBuilder;
struct TensorTypeAndShape;
struct TensorTypeAndShapeBuilder;
struct MapType;
struct MapTypeBuilder;
struct SequenceType;
struct SequenceTypeBuilder;
struct EdgeEnd;
struct NodeEdge;
@ -267,29 +273,35 @@ inline const char *EnumNameNodeType(NodeType e) {
enum TypeInfoValue {
TypeInfoValue_NONE = 0,
TypeInfoValue_tensor_type = 1,
TypeInfoValue_sequence_type = 2,
TypeInfoValue_map_type = 3,
TypeInfoValue_MIN = TypeInfoValue_NONE,
TypeInfoValue_MAX = TypeInfoValue_tensor_type
TypeInfoValue_MAX = TypeInfoValue_map_type
};
inline const TypeInfoValue (&EnumValuesTypeInfoValue())[2] {
inline const TypeInfoValue (&EnumValuesTypeInfoValue())[4] {
static const TypeInfoValue values[] = {
TypeInfoValue_NONE,
TypeInfoValue_tensor_type
TypeInfoValue_tensor_type,
TypeInfoValue_sequence_type,
TypeInfoValue_map_type
};
return values;
}
inline const char * const *EnumNamesTypeInfoValue() {
static const char * const names[3] = {
static const char * const names[5] = {
"NONE",
"tensor_type",
"sequence_type",
"map_type",
nullptr
};
return names;
}
inline const char *EnumNameTypeInfoValue(TypeInfoValue e) {
if (flatbuffers::IsOutRange(e, TypeInfoValue_NONE, TypeInfoValue_tensor_type)) return "";
if (flatbuffers::IsOutRange(e, TypeInfoValue_NONE, TypeInfoValue_map_type)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesTypeInfoValue()[index];
}
@ -302,6 +314,14 @@ template<> struct TypeInfoValueTraits<onnxruntime::experimental::fbs::TensorType
static const TypeInfoValue enum_value = TypeInfoValue_tensor_type;
};
// Associates the SequenceType table with its TypeInfoValue union discriminator.
template<> struct TypeInfoValueTraits<onnxruntime::experimental::fbs::SequenceType> {
static const TypeInfoValue enum_value = TypeInfoValue_sequence_type;
};
// Associates the MapType table with its TypeInfoValue union discriminator.
template<> struct TypeInfoValueTraits<onnxruntime::experimental::fbs::MapType> {
static const TypeInfoValue enum_value = TypeInfoValue_map_type;
};
bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj, TypeInfoValue type);
bool VerifyTypeInfoValueVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
@ -577,6 +597,100 @@ inline flatbuffers::Offset<TensorTypeAndShape> CreateTensorTypeAndShape(
return builder_.Finish();
}
// Read-only accessor over a serialized MapType table (key data type + value TypeInfo).
struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef MapTypeBuilder Builder;
// vtable offsets of the table's fields
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_KEY_TYPE = 4,
VT_VALUE_TYPE = 6
};
// Key data type, read as an int32 field with a default of 0.
onnxruntime::experimental::fbs::TensorDataType key_type() const {
return static_cast<onnxruntime::experimental::fbs::TensorDataType>(GetField<int32_t>(VT_KEY_TYPE, 0));
}
// Pointer to the nested TypeInfo describing the map values.
// NOTE(review): presumably nullptr when the field was not written - callers should check.
const onnxruntime::experimental::fbs::TypeInfo *value_type() const {
return GetPointer<const onnxruntime::experimental::fbs::TypeInfo *>(VT_VALUE_TYPE);
}
// Verifies this table: the key field, the value offset and the nested value TypeInfo table.
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_KEY_TYPE) &&
VerifyOffset(verifier, VT_VALUE_TYPE) &&
verifier.VerifyTable(value_type()) &&
verifier.EndTable();
}
};
// Writes a MapType table into a FlatBufferBuilder: the constructor starts the table,
// the add_* methods set individual fields and Finish() ends the table.
struct MapTypeBuilder {
typedef MapType Table;
flatbuffers::FlatBufferBuilder &fbb_;
// value returned by StartTable(), handed back to EndTable() in Finish()
flatbuffers::uoffset_t start_;
// Sets the key data type; written as an int32 with 0 passed as the default value.
void add_key_type(onnxruntime::experimental::fbs::TensorDataType key_type) {
fbb_.AddElement<int32_t>(MapType::VT_KEY_TYPE, static_cast<int32_t>(key_type), 0);
}
// Sets the offset of the (already serialized) value TypeInfo.
void add_value_type(flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> value_type) {
fbb_.AddOffset(MapType::VT_VALUE_TYPE, value_type);
}
explicit MapTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
// Ends the table and returns its offset.
flatbuffers::Offset<MapType> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<MapType>(end);
return o;
}
};
// Builds a complete MapType table in one call and returns its offset.
// key_type defaults to TensorDataType_UNDEFINED and value_type to a 0 offset.
inline flatbuffers::Offset<MapType> CreateMapType(
flatbuffers::FlatBufferBuilder &_fbb,
onnxruntime::experimental::fbs::TensorDataType key_type = onnxruntime::experimental::fbs::TensorDataType_UNDEFINED,
flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> value_type = 0) {
MapTypeBuilder builder_(_fbb);
// NOTE(review): the add_* order is as emitted by the generator - do not reorder by hand.
builder_.add_value_type(value_type);
builder_.add_key_type(key_type);
return builder_.Finish();
}
// Read-only accessor over a serialized SequenceType table (a single element TypeInfo).
struct SequenceType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef SequenceTypeBuilder Builder;
// vtable offset of the table's only field
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_ELEM_TYPE = 4
};
// Pointer to the nested TypeInfo describing the sequence elements.
// NOTE(review): presumably nullptr when the field was not written - callers should check.
const onnxruntime::experimental::fbs::TypeInfo *elem_type() const {
return GetPointer<const onnxruntime::experimental::fbs::TypeInfo *>(VT_ELEM_TYPE);
}
// Verifies this table: the element offset and the nested element TypeInfo table.
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_ELEM_TYPE) &&
verifier.VerifyTable(elem_type()) &&
verifier.EndTable();
}
};
// Writes a SequenceType table into a FlatBufferBuilder: the constructor starts the table,
// add_elem_type() sets its field and Finish() ends the table.
struct SequenceTypeBuilder {
typedef SequenceType Table;
flatbuffers::FlatBufferBuilder &fbb_;
// value returned by StartTable(), handed back to EndTable() in Finish()
flatbuffers::uoffset_t start_;
// Sets the offset of the (already serialized) element TypeInfo.
void add_elem_type(flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> elem_type) {
fbb_.AddOffset(SequenceType::VT_ELEM_TYPE, elem_type);
}
explicit SequenceTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
// Ends the table and returns its offset.
flatbuffers::Offset<SequenceType> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<SequenceType>(end);
return o;
}
};
// Builds a complete SequenceType table in one call and returns its offset.
// elem_type defaults to a 0 offset.
inline flatbuffers::Offset<SequenceType> CreateSequenceType(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> elem_type = 0) {
SequenceTypeBuilder builder_(_fbb);
builder_.add_elem_type(elem_type);
return builder_.Finish();
}
struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef NodeEdgeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@ -969,6 +1083,12 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const onnxruntime::experimental::fbs::TensorTypeAndShape *value_as_tensor_type() const {
return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_tensor_type ? static_cast<const onnxruntime::experimental::fbs::TensorTypeAndShape *>(value()) : nullptr;
}
// Returns the union value as a SequenceType when the stored discriminator is
// TypeInfoValue_sequence_type, nullptr otherwise.
const onnxruntime::experimental::fbs::SequenceType *value_as_sequence_type() const {
return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_sequence_type ? static_cast<const onnxruntime::experimental::fbs::SequenceType *>(value()) : nullptr;
}
// Returns the union value as a MapType when the stored discriminator is
// TypeInfoValue_map_type, nullptr otherwise.
const onnxruntime::experimental::fbs::MapType *value_as_map_type() const {
return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_map_type ? static_cast<const onnxruntime::experimental::fbs::MapType *>(value()) : nullptr;
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_DENOTATION) &&
@ -984,6 +1104,14 @@ template<> inline const onnxruntime::experimental::fbs::TensorTypeAndShape *Type
return value_as_tensor_type();
}
// value_as<SequenceType>() forwards to value_as_sequence_type().
template<> inline const onnxruntime::experimental::fbs::SequenceType *TypeInfo::value_as<onnxruntime::experimental::fbs::SequenceType>() const {
return value_as_sequence_type();
}
// value_as<MapType>() forwards to value_as_map_type().
template<> inline const onnxruntime::experimental::fbs::MapType *TypeInfo::value_as<onnxruntime::experimental::fbs::MapType>() const {
return value_as_map_type();
}
struct TypeInfoBuilder {
typedef TypeInfo Table;
flatbuffers::FlatBufferBuilder &fbb_;
@ -2016,6 +2144,14 @@ inline bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj
auto ptr = reinterpret_cast<const onnxruntime::experimental::fbs::TensorTypeAndShape *>(obj);
return verifier.VerifyTable(ptr);
}
case TypeInfoValue_sequence_type: {
auto ptr = reinterpret_cast<const onnxruntime::experimental::fbs::SequenceType *>(obj);
return verifier.VerifyTable(ptr);
}
case TypeInfoValue_map_type: {
auto ptr = reinterpret_cast<const onnxruntime::experimental::fbs::MapType *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}

View file

@ -13,7 +13,11 @@ namespace experimental {
namespace utils {
#if !defined(ORT_MINIMAL_BUILD)
static flatbuffers::Offset<fbs::Dimension> GetTensorDimensionOrtFormat(
// Forward declaration: the sequence and map savers call this to serialize their nested
// element/value types, so it has to be visible before they are defined.
static Status SaveTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TypeProto& type_proto,
flatbuffers::Offset<fbs::TypeInfo>& fbs_type_info) ORT_MUST_USE_RESULT;
static flatbuffers::Offset<fbs::Dimension> SaveTensorDimensionOrtFormat(
flatbuffers::FlatBufferBuilder& builder,
const TensorShapeProto_Dimension& tensor_shape_dim) {
auto denotation = builder.CreateString(tensor_shape_dim.denotation());
@ -29,26 +33,45 @@ static flatbuffers::Offset<fbs::Dimension> GetTensorDimensionOrtFormat(
return fbs::CreateDimension(builder, dim_val, denotation);
}
static Status GetTensorShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TensorShapeProto& tensor_shape_proto,
flatbuffers::Offset<fbs::Shape>& fbs_shape) {
static Status SaveTensorShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TensorShapeProto& tensor_shape_proto,
flatbuffers::Offset<fbs::Shape>& fbs_shape) {
std::vector<flatbuffers::Offset<fbs::Dimension>> dim;
dim.reserve(tensor_shape_proto.dim_size());
for (const auto& d : tensor_shape_proto.dim()) {
auto fbs_d = GetTensorDimensionOrtFormat(builder, d);
auto fbs_d = SaveTensorDimensionOrtFormat(builder, d);
dim.push_back(fbs_d);
}
fbs_shape = fbs::CreateShapeDirect(builder, &dim);
return Status::OK();
}
static Status GetTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TypeProto_Tensor& tensor_type_proto,
flatbuffers::Offset<fbs::TensorTypeAndShape>& fbs_tensor_type) {
// Serializes an ONNX sequence type. The element type is written first as a nested
// fbs::TypeInfo; the resulting offset is then wrapped in an fbs::SequenceType whose
// offset is returned through fbs_sequence_type. Any error from saving the element
// type is propagated to the caller.
static Status SaveSequenceTypeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
                                        const TypeProto_Sequence& sequence_type_proto,
                                        flatbuffers::Offset<fbs::SequenceType>& fbs_sequence_type) {
  const auto& elem_type_proto = sequence_type_proto.elem_type();

  flatbuffers::Offset<fbs::TypeInfo> elem_type_offset;
  ORT_RETURN_IF_ERROR(SaveTypeInfoOrtFormat(builder, elem_type_proto, elem_type_offset));

  fbs_sequence_type = fbs::CreateSequenceType(builder, elem_type_offset);
  return Status::OK();
}
// Serializes an ONNX map type. The value type is written first as a nested
// fbs::TypeInfo; the key type is carried over with a plain static_cast to
// fbs::TensorDataType. The offset of the resulting fbs::MapType is returned through
// fbs_map_type. Any error from saving the value type is propagated to the caller.
static Status SaveMapTypeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
                                   const TypeProto_Map& map_type_proto,
                                   flatbuffers::Offset<fbs::MapType>& fbs_map_type) {
  const auto& value_type_proto = map_type_proto.value_type();

  flatbuffers::Offset<fbs::TypeInfo> value_type_offset;
  ORT_RETURN_IF_ERROR(SaveTypeInfoOrtFormat(builder, value_type_proto, value_type_offset));

  const auto fbs_key_type = static_cast<fbs::TensorDataType>(map_type_proto.key_type());
  fbs_map_type = fbs::CreateMapType(builder, fbs_key_type, value_type_offset);
  return Status::OK();
}
static Status SaveTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TypeProto_Tensor& tensor_type_proto,
flatbuffers::Offset<fbs::TensorTypeAndShape>& fbs_tensor_type) {
// A flatbuffers::Offset of 0 means this shape is missing (was null when serializing)
flatbuffers::Offset<fbs::Shape> shape = 0;
if (tensor_type_proto.has_shape()) {
ORT_RETURN_IF_ERROR(GetTensorShapeOrtFormat(builder, tensor_type_proto.shape(), shape));
ORT_RETURN_IF_ERROR(SaveTensorShapeOrtFormat(builder, tensor_type_proto.shape(), shape));
}
fbs_tensor_type = fbs::CreateTensorTypeAndShape(
@ -57,19 +80,37 @@ static Status GetTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& bui
return Status::OK();
}
static Status GetTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TypeProto& type_proto,
flatbuffers::Offset<fbs::TypeInfo>& fbs_type_info) {
static Status SaveTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
const TypeProto& type_proto,
flatbuffers::Offset<fbs::TypeInfo>& fbs_type_info) {
auto denotation = builder.CreateString(type_proto.denotation());
auto value_type = fbs::TypeInfoValue_tensor_type;
flatbuffers::Offset<void> value;
if (type_proto.has_tensor_type()) {
flatbuffers::Offset<fbs::TensorTypeAndShape> tensor_type;
ORT_RETURN_IF_ERROR(
GetTensorTypeAndShapeOrtFormat(builder, type_proto.tensor_type(), tensor_type));
value = tensor_type.Union();
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "We only support tensor type for now");
auto value_case = type_proto.value_case();
switch (value_case) {
case TypeProto::kTensorType: {
flatbuffers::Offset<fbs::TensorTypeAndShape> fbs_tensor_type;
ORT_RETURN_IF_ERROR(
SaveTensorTypeAndShapeOrtFormat(builder, type_proto.tensor_type(), fbs_tensor_type));
value = fbs_tensor_type.Union();
} break;
case TypeProto::kSequenceType: {
value_type = fbs::TypeInfoValue_sequence_type;
flatbuffers::Offset<fbs::SequenceType> fbs_sequence_type;
ORT_RETURN_IF_ERROR(
SaveSequenceTypeOrtFormat(builder, type_proto.sequence_type(), fbs_sequence_type));
value = fbs_sequence_type.Union();
} break;
case TypeProto::kMapType: {
value_type = fbs::TypeInfoValue_map_type;
flatbuffers::Offset<fbs::MapType> fbs_map_type;
ORT_RETURN_IF_ERROR(
SaveMapTypeOrtFormat(builder, type_proto.map_type(), fbs_map_type));
value = fbs_map_type.Union();
} break;
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "We do not support type [", value_case, "] for now");
} break;
}
fbs::TypeInfoBuilder tb(builder);
@ -88,7 +129,7 @@ Status SaveValueInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
flatbuffers::Offset<fbs::TypeInfo> type_info = 0; // 0 indicates null
if (value_info_proto.has_type()) {
ORT_RETURN_IF_ERROR(
GetTypeInfoOrtFormat(builder, value_info_proto.type(), type_info));
SaveTypeInfoOrtFormat(builder, value_info_proto.type(), type_info));
} else {
// we have a NodeArg for missing optional values (empty name, no type) so allow for that.
// everything else should have type info
@ -126,7 +167,7 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder,
string_data = builder.CreateVectorOfStrings(string_data_vec);
} else {
std::unique_ptr<uint8_t[]> unpacked_tensor;
size_t tensor_byte_size;
size_t tensor_byte_size = 0;
ORT_RETURN_IF_ERROR(
onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor, tensor_byte_size));
raw_data = builder.CreateVector(unpacked_tensor.get(), tensor_byte_size);
@ -228,6 +269,9 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
#undef GET_FBS_ATTR
#undef GET_DATA_VEC
// Forward declaration: the sequence and map loaders call this to deserialize their nested
// element/value types, so it has to be visible before they are defined.
static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info,
TypeProto& type_proto) ORT_MUST_USE_RESULT;
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor,
TensorProto& initializer) {
initializer.Clear();
@ -301,6 +345,23 @@ static Status LoadTensorTypeAndShapeOrtFormat(const fbs::TensorTypeAndShape& fbs
return Status::OK();
}
// Rebuilds an ONNX sequence type from its ORT format representation.
// Fails if the serialized sequence has no element type info; otherwise the element
// type is loaded into sequence_type_proto.elem_type and any error from that load is
// propagated to the caller.
static Status LoadSequenceTypeOrtFormat(const fbs::SequenceType& fbs_sequence_type,
                                        TypeProto_Sequence& sequence_type_proto) {
  const auto* fbs_elem_type = fbs_sequence_type.elem_type();
  // the field being checked is elem_type, so say so in the message
  ORT_RETURN_IF(nullptr == fbs_elem_type, "Null elem type info in fbs::SequenceType. Invalid ORT format model.");
  ORT_RETURN_IF_ERROR(LoadTypeInfoOrtFormat(*fbs_elem_type, *sequence_type_proto.mutable_elem_type()));
  return Status::OK();
}
// Rebuilds an ONNX map type from its ORT format representation.
// The key type is copied first; the call then fails if the serialized map has no
// value type info, otherwise the value type is loaded into map_type_proto.value_type
// and any error from that load is propagated to the caller.
static Status LoadMapTypeOrtFormat(const fbs::MapType& fbs_map_type,
                                   TypeProto_Map& map_type_proto) {
  const auto fbs_key_type = fbs_map_type.key_type();
  map_type_proto.set_key_type(fbs_key_type);

  const auto* fbs_value_type = fbs_map_type.value_type();
  ORT_RETURN_IF(nullptr == fbs_value_type, "Null value type info in fbs::MapType. Invalid ORT format model.");

  auto* value_type_proto = map_type_proto.mutable_value_type();
  ORT_RETURN_IF_ERROR(LoadTypeInfoOrtFormat(*fbs_value_type, *value_type_proto));
  return Status::OK();
}
static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info,
TypeProto& type_proto) {
LoadStringFromOrtFormat(*type_proto.mutable_denotation(), fbs_type_info.denotation());
@ -309,8 +370,16 @@ static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info,
auto fbs_tensor_type = fbs_type_info.value_as_tensor_type();
ORT_RETURN_IF(nullptr == fbs_tensor_type, "Null tensor type info. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(LoadTensorTypeAndShapeOrtFormat(*fbs_tensor_type, *type_proto.mutable_tensor_type()));
} else if (value_type == fbs::TypeInfoValue_sequence_type) {
auto fbs_sequence_type = fbs_type_info.value_as_sequence_type();
ORT_RETURN_IF(nullptr == fbs_sequence_type, "Null sequence type info. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(LoadSequenceTypeOrtFormat(*fbs_sequence_type, *type_proto.mutable_sequence_type()));
} else if (value_type == fbs::TypeInfoValue_map_type) {
auto fbs_map_type = fbs_type_info.value_as_map_type();
ORT_RETURN_IF(nullptr == fbs_map_type, "Null map type info. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(LoadMapTypeOrtFormat(*fbs_map_type, *type_proto.mutable_map_type()));
} else {
// TODO: This may be required for traditional ML models.
// We do not support SparseTensor and Opaque for now
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Type:", value_type, " is not supported currently");
}

View file

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/data_types.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
@ -23,17 +24,41 @@ class InferenceSessionGetGraphWrapper : public InferenceSession {
const Environment& env) : InferenceSession(session_options, env) {
}
const Graph& GetGraph() {
const Graph& GetGraph() const {
return model_->MainGraph();
}
const SessionState& GetSessionState() {
const SessionState& GetSessionState() const {
return InferenceSession::GetSessionState();
}
};
namespace test {
// Describes one "load a model, run it, check the outputs" scenario consumed by RunOrtModel.
struct OrtModelTestInfo {
// path of the model file handed to InferenceSession::Load
std::basic_string<ORTCHAR_T> model_filename;
// value assigned to SessionOptions::session_logid
std::string logid;
// feeds passed to Run
NameMLValMap inputs;
// names of the outputs requested from Run
std::vector<std::string> output_names;
// callback that receives the fetched outputs and asserts on them
std::function<void(const std::vector<OrtValue>&)> output_verifier;
// (key, value) pairs added to the SessionOptions via AddConfigEntry
std::vector<std::pair<std::string, std::string>> configs;
};
void RunOrtModel(const OrtModelTestInfo& test_info) {
SessionOptions so;
so.session_logid = test_info.logid;
for (const auto& config : test_info.configs)
so.AddConfigEntry(config.first.c_str(), config.second.c_str());
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename
ASSERT_STATUS_OK(session_object.Initialize());
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object.Run(test_info.inputs, test_info.output_names, &fetches));
test_info.output_verifier(fetches);
}
#if !defined(ORT_MINIMAL_BUILD)
// Same Tensor from ONNX and ORT format will have different binary representation, need to compare value by value
static void CompareTensors(const OrtValue& left_value, const OrtValue& right_value) {
@ -56,89 +81,60 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val
}
}
// Asserts that two TypeProtos describe the same type.
// The denotation and the kind of value (tensor / sequence / map) must match. Tensors are
// compared by element type and dimension-by-dimension (dim_value and dim_param);
// sequences and maps recurse into their element / value types. Any other kind of value
// is reported as a failure.
static void CompareTypeProtos(const TypeProto& left_type_proto, const TypeProto& right_type_proto) {
  ASSERT_EQ(left_type_proto.denotation(), right_type_proto.denotation());

  // both sides must hold the same kind of value
  ASSERT_EQ(left_type_proto.has_tensor_type(), right_type_proto.has_tensor_type());
  ASSERT_EQ(left_type_proto.has_sequence_type(), right_type_proto.has_sequence_type());
  ASSERT_EQ(left_type_proto.has_map_type(), right_type_proto.has_map_type());

  if (left_type_proto.has_tensor_type()) {
    const auto& lhs_tensor = left_type_proto.tensor_type();
    const auto& rhs_tensor = right_type_proto.tensor_type();
    ASSERT_EQ(lhs_tensor.elem_type(), rhs_tensor.elem_type());

    const auto& lhs_shape = lhs_tensor.shape();
    const auto& rhs_shape = rhs_tensor.shape();
    const int rank = lhs_shape.dim_size();
    ASSERT_EQ(rank, rhs_shape.dim_size());

    for (int axis = 0; axis < rank; ++axis) {
      const auto& lhs_dim = lhs_shape.dim(axis);
      const auto& rhs_dim = rhs_shape.dim(axis);
      ASSERT_EQ(lhs_dim.has_dim_value(), rhs_dim.has_dim_value());
      ASSERT_EQ(lhs_dim.dim_value(), rhs_dim.dim_value());
      ASSERT_EQ(lhs_dim.has_dim_param(), rhs_dim.has_dim_param());
      ASSERT_EQ(lhs_dim.dim_param(), rhs_dim.dim_param());
    }
    return;
  }

  if (left_type_proto.has_sequence_type()) {
    CompareTypeProtos(left_type_proto.sequence_type().elem_type(),
                      right_type_proto.sequence_type().elem_type());
    return;
  }

  if (left_type_proto.has_map_type()) {
    const auto& lhs_map = left_type_proto.map_type();
    const auto& rhs_map = right_type_proto.map_type();
    ASSERT_EQ(lhs_map.key_type(), rhs_map.key_type());
    CompareTypeProtos(lhs_map.value_type(), rhs_map.value_type());
    return;
  }

  // We do not support SparseTensor and Opaque for now
  FAIL();
}
static void CompareValueInfos(const ValueInfoProto& left, const ValueInfoProto& right) {
ASSERT_EQ(left.name(), right.name());
ASSERT_EQ(left.doc_string(), right.doc_string());
std::string left_data;
std::string right_data;
const auto& left_type_proto = left.type();
const auto& right_type_proto = right.type();
ASSERT_EQ(left_type_proto.denotation(), right_type_proto.denotation());
ASSERT_TRUE(left_type_proto.has_tensor_type());
ASSERT_TRUE(right_type_proto.has_tensor_type());
const auto& left_tensor_type = left_type_proto.tensor_type();
const auto& right_tensor_type = right_type_proto.tensor_type();
ASSERT_EQ(left_tensor_type.elem_type(), right_tensor_type.elem_type());
const auto& left_shape = left_tensor_type.shape();
const auto& right_shape = right_tensor_type.shape();
ASSERT_EQ(left_shape.dim_size(), right_shape.dim_size());
for (int i = 0; i < left_shape.dim_size(); i++) {
const auto& left_dim = left_shape.dim(i);
const auto& right_dim = right_shape.dim(i);
ASSERT_EQ(left_dim.has_dim_value(), right_dim.has_dim_value());
ASSERT_EQ(left_dim.dim_value(), right_dim.dim_value());
ASSERT_EQ(left_dim.has_dim_param(), right_dim.has_dim_param());
ASSERT_EQ(left_dim.dim_param(), right_dim.dim_param());
}
CompareTypeProtos(left.type(), right.type());
}
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
const auto output_file = ORT_TSTR("ort_github_issue_4031.onnx.ort");
SessionOptions so;
so.session_logid = "SerializeToOrtFormat";
so.optimized_model_filepath = output_file;
// not strictly necessary - type should be inferred from the filename
so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT, "ORT");
void CompareGraphAndSessionState(const InferenceSessionGetGraphWrapper& session_object_1,
const InferenceSessionGetGraphWrapper& session_object_2) {
const auto& graph_1 = session_object_1.GetGraph();
const auto& graph_2 = session_object_2.GetGraph();
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
const auto& session_state_1 = session_object_1.GetSessionState();
const auto& session_state_2 = session_object_2.GetSessionState();
// create .ort file during Initialize due to values in SessionOptions
ASSERT_STATUS_OK(session_object.Load(ORT_TSTR("testdata/ort_github_issue_4031.onnx")));
ASSERT_STATUS_OK(session_object.Initialize());
// create inputs
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
&ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("state_var_in", ml_value));
// prepare outputs
std::vector<std::string> output_names{"state_var_out"};
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
SessionOptions so2;
so.session_logid = "LoadOrtFormat";
// not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we
// think we're testing set it.
so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT");
// load serialized version
InferenceSessionGetGraphWrapper session_object2{so2, GetEnvironment()};
ASSERT_STATUS_OK(session_object2.Load(output_file));
ASSERT_STATUS_OK(session_object2.Initialize());
// compare contents on Graph instances
const auto& graph = session_object.GetGraph();
const auto& graph2 = session_object2.GetGraph();
const auto& session_state = session_object.GetSessionState();
const auto& session_state2 = session_object2.GetSessionState();
const auto& name_idx_map = session_state.GetOrtValueNameIdxMap();
const auto& name_idx_map2 = session_state2.GetOrtValueNameIdxMap();
const auto& i1 = session_state.GetInitializedTensors();
const auto& i2 = session_state2.GetInitializedTensors();
const auto& i1 = session_state_1.GetInitializedTensors();
const auto& i2 = session_state_2.GetInitializedTensors();
ASSERT_EQ(i1.size(), i2.size());
for (const auto& pair : i1) {
@ -148,24 +144,12 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
const OrtValue& left = pair.second;
const OrtValue& right = iter->second;
CompareTensors(left, right);
// check NodeArgs for both initializers. need to get name from map as we store the initialized tensors against
// their OrtValueIdx in SessionState.
std::string name;
std::string name2;
ASSERT_STATUS_OK(name_idx_map.GetName(pair.first, name));
ASSERT_STATUS_OK(name_idx_map2.GetName(pair.first, name2));
ASSERT_EQ(name, name2);
const auto& left_nodearg = *graph.GetNodeArg(name);
const auto& right_nodearg = *graph2.GetNodeArg(name2);
CompareValueInfos(left_nodearg.ToProto(), right_nodearg.ToProto());
}
// check all node args are fine
for (const auto& input : graph.GetInputs()) {
const auto& left = *graph.GetNodeArg(input->Name());
const auto* right = graph2.GetNodeArg(input->Name());
for (const auto& input : graph_1.GetInputsIncludingInitializers()) {
const auto& left = *graph_1.GetNodeArg(input->Name());
const auto* right = graph_2.GetNodeArg(input->Name());
ASSERT_TRUE(right != nullptr);
const auto& left_proto = left.ToProto();
@ -173,8 +157,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
CompareValueInfos(left_proto, right_proto);
}
for (const auto& left : graph.Nodes()) {
const auto* right = graph2.GetNode(left.Index());
for (const auto& left : graph_1.Nodes()) {
const auto* right = graph_2.GetNode(left.Index());
ASSERT_TRUE(right != nullptr);
const auto& left_outputs = left.OutputDefs();
const auto& right_outputs = right->OutputDefs();
@ -192,47 +176,165 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
}
}
}
// check results match
std::vector<OrtValue> fetches2;
ASSERT_STATUS_OK(session_object2.Run(feeds, output_names, &fetches2));
const auto& output = fetches[0].Get<Tensor>();
ASSERT_TRUE(output.Shape().Size() == 1);
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
const auto& output2 = fetches2[0].Get<Tensor>();
ASSERT_TRUE(output2.Shape().Size() == 1);
ASSERT_TRUE(output2.Data<float>()[0] == 125.f);
}
#endif
// test that we can deserialize and run a previously saved ORT format model
TEST(OrtModelOnlyTests, LoadOrtFormatModel) {
const auto model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
void SaveAndCompareModels(const std::string& onnx_file, const std::basic_string<ORTCHAR_T>& ort_file) {
SessionOptions so;
so.session_logid = "LoadOrtFormatModel";
so.session_logid = "SerializeToOrtFormat";
so.optimized_model_filepath = ort_file;
// not strictly necessary - type should be inferred from the filename
so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT, "ORT");
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.Load(model_filename)); // infer type from filename
// create .ort file during Initialize due to values in SessionOptions
ASSERT_STATUS_OK(session_object.Load(onnx_file));
ASSERT_STATUS_OK(session_object.Initialize());
SessionOptions so2;
so2.session_logid = "LoadOrtFormat";
// not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we
// think we're testing set it.
so2.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT");
// load serialized version
InferenceSessionGetGraphWrapper session_object2{so2, GetEnvironment()};
ASSERT_STATUS_OK(session_object2.Load(ort_file));
ASSERT_STATUS_OK(session_object2.Initialize());
CompareGraphAndSessionState(session_object, session_object2);
}
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort");
SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file);
OrtModelTestInfo test_info;
test_info.model_filename = ort_file;
test_info.logid = "SerializeToOrtFormat";
test_info.configs.push_back(std::make_pair(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT"));
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
&ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("state_var_in", ml_value));
test_info.inputs.insert(std::make_pair("state_var_in", ml_value));
// prepare outputs
std::vector<std::string> output_names{"state_var_out"};
std::vector<OrtValue> fetches;
test_info.output_names = {"state_var_out"};
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
const auto& output = fetches[0].Get<Tensor>();
ASSERT_TRUE(output.Shape().Size() == 1);
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
};
ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
const auto& output = fetches[0].Get<Tensor>();
ASSERT_TRUE(output.Shape().Size() == 1);
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
RunOrtModel(test_info);
}
#if !defined(DISABLE_ML_OPS)
// Converts the sklearn voting classifier ONNX model to ORT format, compares the saved and
// reloaded sessions via SaveAndCompareModels, then runs the converted model and checks
// its string tensor output and its vector-of-map output.
TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
  const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort");
  SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file);

  OrtModelTestInfo test_info;
  test_info.model_filename = ort_file;
  test_info.logid = "SerializeToOrtFormatMLOps";
  test_info.configs.push_back(std::make_pair(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT"));

  OrtValue ml_value;
  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {3, 2},
                       {0.f, 1.f, 1.f, 1.f, 2.f, 0.f}, &ml_value);
  test_info.inputs.insert(std::make_pair("input", ml_value));

  // prepare outputs
  test_info.output_names = {"output_label", "output_probability"};
  test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
    // both outputs are indexed below, so make sure they are actually there
    ASSERT_EQ(fetches.size(), size_t(2));

    const auto& output_0 = fetches[0].Get<Tensor>();
    const int64_t tensor_size = 3;
    ASSERT_EQ(tensor_size, output_0.Shape().Size());
    const auto& output_0_data = output_0.Data<std::string>();
    for (int64_t i = 0; i < tensor_size; i++)
      ASSERT_TRUE(output_0_data[i] == "A");

    VectorMapStringToFloat expected_output_1 = {{{"A", 0.572734f}, {"B", 0.427266f}},
                                                {{"A", 0.596016f}, {"B", 0.403984f}},
                                                {{"A", 0.656315f}, {"B", 0.343685f}}};
    const auto& actual_output_1 = fetches[1].Get<VectorMapStringToFloat>();
    // compare sizes as size_t to avoid a signed/unsigned comparison inside gtest
    ASSERT_EQ(actual_output_1.size(), size_t(3));
    for (size_t i = 0; i < 3; i++) {
      const auto& expected = expected_output_1[i];
      const auto& actual = actual_output_1[i];
      ASSERT_EQ(actual.size(), size_t(2));
      ASSERT_NEAR(expected.at("A"), actual.at("A"), 1e-6);
      ASSERT_NEAR(expected.at("B"), actual.at("B"), 1e-6);
    }
  };

  RunOrtModel(test_info);
}
#endif // #if !defined(DISABLE_ML_OPS)
#endif // #if !defined(ORT_MINIMAL_BUILD)
// Loads a checked-in ORT format model, runs it on a single-element float input and
// verifies the single-element float output.
TEST(OrtModelOnlyTests, LoadOrtFormatModel) {
  // verifier for the one requested output
  const auto verify_outputs = [](const std::vector<OrtValue>& fetches) {
    const auto& output = fetches[0].Get<Tensor>();
    ASSERT_TRUE(output.Shape().Size() == 1);
    ASSERT_TRUE(output.Data<float>()[0] == 125.f);
  };

  // input tensor of shape {1}
  OrtValue input_value;
  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
                       &input_value);

  OrtModelTestInfo test_info;
  test_info.model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
  test_info.logid = "LoadOrtFormatModel";
  test_info.inputs.insert(std::make_pair("state_var_in", input_value));
  test_info.output_names = {"state_var_out"};
  test_info.output_verifier = verify_outputs;

  RunOrtModel(test_info);
}
#if !defined(DISABLE_ML_OPS)
// test that we can deserialize and run a previously saved ORT format model
// for a model with sequence and map outputs
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) {
  OrtModelTestInfo test_info;
  test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.ort");
  test_info.logid = "LoadOrtFormatModelMLOps";

  OrtValue ml_value;
  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {3, 2},
                       {0.f, 1.f, 1.f, 1.f, 2.f, 0.f}, &ml_value);
  test_info.inputs.insert(std::make_pair("input", ml_value));

  // prepare outputs
  test_info.output_names = {"output_label", "output_probability"};
  test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
    // both outputs are indexed below, so make sure they are actually there
    ASSERT_EQ(fetches.size(), size_t(2));

    const auto& output_0 = fetches[0].Get<Tensor>();
    const int64_t tensor_size = 3;
    ASSERT_EQ(tensor_size, output_0.Shape().Size());
    const auto& output_0_data = output_0.Data<std::string>();
    for (int64_t i = 0; i < tensor_size; i++)
      ASSERT_TRUE(output_0_data[i] == "A");

    VectorMapStringToFloat expected_output_1 = {{{"A", 0.572734f}, {"B", 0.427266f}},
                                                {{"A", 0.596016f}, {"B", 0.403984f}},
                                                {{"A", 0.656315f}, {"B", 0.343685f}}};
    const auto& actual_output_1 = fetches[1].Get<VectorMapStringToFloat>();
    // compare sizes as size_t to avoid a signed/unsigned comparison inside gtest
    ASSERT_EQ(actual_output_1.size(), size_t(3));
    for (size_t i = 0; i < 3; i++) {
      const auto& expected = expected_output_1[i];
      const auto& actual = actual_output_1[i];
      ASSERT_EQ(actual.size(), size_t(2));
      ASSERT_NEAR(expected.at("A"), actual.at("A"), 1e-6);
      ASSERT_NEAR(expected.at("B"), actual.at("B"), 1e-6);
    }
  };

  RunOrtModel(test_info);
}
#endif
} // namespace test
} // namespace onnxruntime

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1 @@
The model sklearn_bin_voting_classifier_soft.onnx is taken from the scikit-learn ONNX converter tests.