mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Add sequence and map support in ORT mobile file format, add UT (#5066)
* Init change * Update schema header * Address review comments * fix for DISABLE_ML_OPS build break * Fix build break Co-authored-by: gwang0000 <62914304+gwang0000@users.noreply.github.com>
This commit is contained in:
parent
de58720a97
commit
d922cb1081
7 changed files with 468 additions and 149 deletions
|
|
@ -68,6 +68,15 @@ table TensorTypeAndShape{
|
|||
shape:Shape;
|
||||
}
|
||||
|
||||
table MapType{
|
||||
key_type:TensorDataType;
|
||||
value_type:onnxruntime.experimental.fbs.TypeInfo;
|
||||
}
|
||||
|
||||
table SequenceType{
|
||||
elem_type:onnxruntime.experimental.fbs.TypeInfo;
|
||||
}
|
||||
|
||||
// Node
|
||||
enum NodeType : int32 {
|
||||
Primitive = 0,
|
||||
|
|
@ -115,6 +124,8 @@ table ValueInfo {
|
|||
// TODO add support of Sequence/Map/SparseTensor/Opaque
|
||||
union TypeInfoValue {
|
||||
tensor_type:TensorTypeAndShape,
|
||||
sequence_type:SequenceType,
|
||||
map_type:MapType,
|
||||
}
|
||||
|
||||
table TypeInfo {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,12 @@ struct DimensionValueBuilder;
|
|||
struct TensorTypeAndShape;
|
||||
struct TensorTypeAndShapeBuilder;
|
||||
|
||||
struct MapType;
|
||||
struct MapTypeBuilder;
|
||||
|
||||
struct SequenceType;
|
||||
struct SequenceTypeBuilder;
|
||||
|
||||
struct EdgeEnd;
|
||||
|
||||
struct NodeEdge;
|
||||
|
|
@ -267,29 +273,35 @@ inline const char *EnumNameNodeType(NodeType e) {
|
|||
enum TypeInfoValue {
|
||||
TypeInfoValue_NONE = 0,
|
||||
TypeInfoValue_tensor_type = 1,
|
||||
TypeInfoValue_sequence_type = 2,
|
||||
TypeInfoValue_map_type = 3,
|
||||
TypeInfoValue_MIN = TypeInfoValue_NONE,
|
||||
TypeInfoValue_MAX = TypeInfoValue_tensor_type
|
||||
TypeInfoValue_MAX = TypeInfoValue_map_type
|
||||
};
|
||||
|
||||
inline const TypeInfoValue (&EnumValuesTypeInfoValue())[2] {
|
||||
inline const TypeInfoValue (&EnumValuesTypeInfoValue())[4] {
|
||||
static const TypeInfoValue values[] = {
|
||||
TypeInfoValue_NONE,
|
||||
TypeInfoValue_tensor_type
|
||||
TypeInfoValue_tensor_type,
|
||||
TypeInfoValue_sequence_type,
|
||||
TypeInfoValue_map_type
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesTypeInfoValue() {
|
||||
static const char * const names[3] = {
|
||||
static const char * const names[5] = {
|
||||
"NONE",
|
||||
"tensor_type",
|
||||
"sequence_type",
|
||||
"map_type",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameTypeInfoValue(TypeInfoValue e) {
|
||||
if (flatbuffers::IsOutRange(e, TypeInfoValue_NONE, TypeInfoValue_tensor_type)) return "";
|
||||
if (flatbuffers::IsOutRange(e, TypeInfoValue_NONE, TypeInfoValue_map_type)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesTypeInfoValue()[index];
|
||||
}
|
||||
|
|
@ -302,6 +314,14 @@ template<> struct TypeInfoValueTraits<onnxruntime::experimental::fbs::TensorType
|
|||
static const TypeInfoValue enum_value = TypeInfoValue_tensor_type;
|
||||
};
|
||||
|
||||
template<> struct TypeInfoValueTraits<onnxruntime::experimental::fbs::SequenceType> {
|
||||
static const TypeInfoValue enum_value = TypeInfoValue_sequence_type;
|
||||
};
|
||||
|
||||
template<> struct TypeInfoValueTraits<onnxruntime::experimental::fbs::MapType> {
|
||||
static const TypeInfoValue enum_value = TypeInfoValue_map_type;
|
||||
};
|
||||
|
||||
bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj, TypeInfoValue type);
|
||||
bool VerifyTypeInfoValueVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
|
||||
|
||||
|
|
@ -577,6 +597,100 @@ inline flatbuffers::Offset<TensorTypeAndShape> CreateTensorTypeAndShape(
|
|||
return builder_.Finish();
|
||||
}
|
||||
|
||||
struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef MapTypeBuilder Builder;
|
||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||
VT_KEY_TYPE = 4,
|
||||
VT_VALUE_TYPE = 6
|
||||
};
|
||||
onnxruntime::experimental::fbs::TensorDataType key_type() const {
|
||||
return static_cast<onnxruntime::experimental::fbs::TensorDataType>(GetField<int32_t>(VT_KEY_TYPE, 0));
|
||||
}
|
||||
const onnxruntime::experimental::fbs::TypeInfo *value_type() const {
|
||||
return GetPointer<const onnxruntime::experimental::fbs::TypeInfo *>(VT_VALUE_TYPE);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<int32_t>(verifier, VT_KEY_TYPE) &&
|
||||
VerifyOffset(verifier, VT_VALUE_TYPE) &&
|
||||
verifier.VerifyTable(value_type()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
};
|
||||
|
||||
struct MapTypeBuilder {
|
||||
typedef MapType Table;
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_key_type(onnxruntime::experimental::fbs::TensorDataType key_type) {
|
||||
fbb_.AddElement<int32_t>(MapType::VT_KEY_TYPE, static_cast<int32_t>(key_type), 0);
|
||||
}
|
||||
void add_value_type(flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> value_type) {
|
||||
fbb_.AddOffset(MapType::VT_VALUE_TYPE, value_type);
|
||||
}
|
||||
explicit MapTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
flatbuffers::Offset<MapType> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<MapType>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<MapType> CreateMapType(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
onnxruntime::experimental::fbs::TensorDataType key_type = onnxruntime::experimental::fbs::TensorDataType_UNDEFINED,
|
||||
flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> value_type = 0) {
|
||||
MapTypeBuilder builder_(_fbb);
|
||||
builder_.add_value_type(value_type);
|
||||
builder_.add_key_type(key_type);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
struct SequenceType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef SequenceTypeBuilder Builder;
|
||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||
VT_ELEM_TYPE = 4
|
||||
};
|
||||
const onnxruntime::experimental::fbs::TypeInfo *elem_type() const {
|
||||
return GetPointer<const onnxruntime::experimental::fbs::TypeInfo *>(VT_ELEM_TYPE);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyOffset(verifier, VT_ELEM_TYPE) &&
|
||||
verifier.VerifyTable(elem_type()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
};
|
||||
|
||||
struct SequenceTypeBuilder {
|
||||
typedef SequenceType Table;
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_elem_type(flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> elem_type) {
|
||||
fbb_.AddOffset(SequenceType::VT_ELEM_TYPE, elem_type);
|
||||
}
|
||||
explicit SequenceTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
flatbuffers::Offset<SequenceType> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<SequenceType>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<SequenceType> CreateSequenceType(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<onnxruntime::experimental::fbs::TypeInfo> elem_type = 0) {
|
||||
SequenceTypeBuilder builder_(_fbb);
|
||||
builder_.add_elem_type(elem_type);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef NodeEdgeBuilder Builder;
|
||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||
|
|
@ -969,6 +1083,12 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const onnxruntime::experimental::fbs::TensorTypeAndShape *value_as_tensor_type() const {
|
||||
return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_tensor_type ? static_cast<const onnxruntime::experimental::fbs::TensorTypeAndShape *>(value()) : nullptr;
|
||||
}
|
||||
const onnxruntime::experimental::fbs::SequenceType *value_as_sequence_type() const {
|
||||
return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_sequence_type ? static_cast<const onnxruntime::experimental::fbs::SequenceType *>(value()) : nullptr;
|
||||
}
|
||||
const onnxruntime::experimental::fbs::MapType *value_as_map_type() const {
|
||||
return value_type() == onnxruntime::experimental::fbs::TypeInfoValue_map_type ? static_cast<const onnxruntime::experimental::fbs::MapType *>(value()) : nullptr;
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyOffset(verifier, VT_DENOTATION) &&
|
||||
|
|
@ -984,6 +1104,14 @@ template<> inline const onnxruntime::experimental::fbs::TensorTypeAndShape *Type
|
|||
return value_as_tensor_type();
|
||||
}
|
||||
|
||||
template<> inline const onnxruntime::experimental::fbs::SequenceType *TypeInfo::value_as<onnxruntime::experimental::fbs::SequenceType>() const {
|
||||
return value_as_sequence_type();
|
||||
}
|
||||
|
||||
template<> inline const onnxruntime::experimental::fbs::MapType *TypeInfo::value_as<onnxruntime::experimental::fbs::MapType>() const {
|
||||
return value_as_map_type();
|
||||
}
|
||||
|
||||
struct TypeInfoBuilder {
|
||||
typedef TypeInfo Table;
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
|
|
@ -2016,6 +2144,14 @@ inline bool VerifyTypeInfoValue(flatbuffers::Verifier &verifier, const void *obj
|
|||
auto ptr = reinterpret_cast<const onnxruntime::experimental::fbs::TensorTypeAndShape *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case TypeInfoValue_sequence_type: {
|
||||
auto ptr = reinterpret_cast<const onnxruntime::experimental::fbs::SequenceType *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case TypeInfoValue_map_type: {
|
||||
auto ptr = reinterpret_cast<const onnxruntime::experimental::fbs::MapType *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,11 @@ namespace experimental {
|
|||
namespace utils {
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
static flatbuffers::Offset<fbs::Dimension> GetTensorDimensionOrtFormat(
|
||||
static Status SaveTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto& type_proto,
|
||||
flatbuffers::Offset<fbs::TypeInfo>& fbs_type_info) ORT_MUST_USE_RESULT;
|
||||
|
||||
static flatbuffers::Offset<fbs::Dimension> SaveTensorDimensionOrtFormat(
|
||||
flatbuffers::FlatBufferBuilder& builder,
|
||||
const TensorShapeProto_Dimension& tensor_shape_dim) {
|
||||
auto denotation = builder.CreateString(tensor_shape_dim.denotation());
|
||||
|
|
@ -29,26 +33,45 @@ static flatbuffers::Offset<fbs::Dimension> GetTensorDimensionOrtFormat(
|
|||
return fbs::CreateDimension(builder, dim_val, denotation);
|
||||
}
|
||||
|
||||
static Status GetTensorShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TensorShapeProto& tensor_shape_proto,
|
||||
flatbuffers::Offset<fbs::Shape>& fbs_shape) {
|
||||
static Status SaveTensorShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TensorShapeProto& tensor_shape_proto,
|
||||
flatbuffers::Offset<fbs::Shape>& fbs_shape) {
|
||||
std::vector<flatbuffers::Offset<fbs::Dimension>> dim;
|
||||
dim.reserve(tensor_shape_proto.dim_size());
|
||||
for (const auto& d : tensor_shape_proto.dim()) {
|
||||
auto fbs_d = GetTensorDimensionOrtFormat(builder, d);
|
||||
auto fbs_d = SaveTensorDimensionOrtFormat(builder, d);
|
||||
dim.push_back(fbs_d);
|
||||
}
|
||||
fbs_shape = fbs::CreateShapeDirect(builder, &dim);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status GetTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto_Tensor& tensor_type_proto,
|
||||
flatbuffers::Offset<fbs::TensorTypeAndShape>& fbs_tensor_type) {
|
||||
static Status SaveSequenceTypeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto_Sequence& sequence_type_proto,
|
||||
flatbuffers::Offset<fbs::SequenceType>& fbs_sequence_type) {
|
||||
flatbuffers::Offset<fbs::TypeInfo> fbs_type_info;
|
||||
ORT_RETURN_IF_ERROR(SaveTypeInfoOrtFormat(builder, sequence_type_proto.elem_type(), fbs_type_info));
|
||||
fbs_sequence_type = fbs::CreateSequenceType(builder, fbs_type_info);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status SaveMapTypeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto_Map& map_type_proto,
|
||||
flatbuffers::Offset<fbs::MapType>& fbs_map_type) {
|
||||
flatbuffers::Offset<fbs::TypeInfo> fbs_type_info;
|
||||
ORT_RETURN_IF_ERROR(SaveTypeInfoOrtFormat(builder, map_type_proto.value_type(), fbs_type_info));
|
||||
fbs_map_type = fbs::CreateMapType(
|
||||
builder, static_cast<fbs::TensorDataType>(map_type_proto.key_type()), fbs_type_info);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status SaveTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto_Tensor& tensor_type_proto,
|
||||
flatbuffers::Offset<fbs::TensorTypeAndShape>& fbs_tensor_type) {
|
||||
// A flatbuffers::Offset of 0 means this shape is missing (was null when serializing)
|
||||
flatbuffers::Offset<fbs::Shape> shape = 0;
|
||||
if (tensor_type_proto.has_shape()) {
|
||||
ORT_RETURN_IF_ERROR(GetTensorShapeOrtFormat(builder, tensor_type_proto.shape(), shape));
|
||||
ORT_RETURN_IF_ERROR(SaveTensorShapeOrtFormat(builder, tensor_type_proto.shape(), shape));
|
||||
}
|
||||
|
||||
fbs_tensor_type = fbs::CreateTensorTypeAndShape(
|
||||
|
|
@ -57,19 +80,37 @@ static Status GetTensorTypeAndShapeOrtFormat(flatbuffers::FlatBufferBuilder& bui
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status GetTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto& type_proto,
|
||||
flatbuffers::Offset<fbs::TypeInfo>& fbs_type_info) {
|
||||
static Status SaveTypeInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
||||
const TypeProto& type_proto,
|
||||
flatbuffers::Offset<fbs::TypeInfo>& fbs_type_info) {
|
||||
auto denotation = builder.CreateString(type_proto.denotation());
|
||||
auto value_type = fbs::TypeInfoValue_tensor_type;
|
||||
flatbuffers::Offset<void> value;
|
||||
if (type_proto.has_tensor_type()) {
|
||||
flatbuffers::Offset<fbs::TensorTypeAndShape> tensor_type;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
GetTensorTypeAndShapeOrtFormat(builder, type_proto.tensor_type(), tensor_type));
|
||||
value = tensor_type.Union();
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "We only support tensor type for now");
|
||||
auto value_case = type_proto.value_case();
|
||||
switch (value_case) {
|
||||
case TypeProto::kTensorType: {
|
||||
flatbuffers::Offset<fbs::TensorTypeAndShape> fbs_tensor_type;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
SaveTensorTypeAndShapeOrtFormat(builder, type_proto.tensor_type(), fbs_tensor_type));
|
||||
value = fbs_tensor_type.Union();
|
||||
} break;
|
||||
case TypeProto::kSequenceType: {
|
||||
value_type = fbs::TypeInfoValue_sequence_type;
|
||||
flatbuffers::Offset<fbs::SequenceType> fbs_sequence_type;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
SaveSequenceTypeOrtFormat(builder, type_proto.sequence_type(), fbs_sequence_type));
|
||||
value = fbs_sequence_type.Union();
|
||||
} break;
|
||||
case TypeProto::kMapType: {
|
||||
value_type = fbs::TypeInfoValue_map_type;
|
||||
flatbuffers::Offset<fbs::MapType> fbs_map_type;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
SaveMapTypeOrtFormat(builder, type_proto.map_type(), fbs_map_type));
|
||||
value = fbs_map_type.Union();
|
||||
} break;
|
||||
default: {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "We do not support type [", value_case, "] for now");
|
||||
} break;
|
||||
}
|
||||
|
||||
fbs::TypeInfoBuilder tb(builder);
|
||||
|
|
@ -88,7 +129,7 @@ Status SaveValueInfoOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
flatbuffers::Offset<fbs::TypeInfo> type_info = 0; // 0 indicates null
|
||||
if (value_info_proto.has_type()) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
GetTypeInfoOrtFormat(builder, value_info_proto.type(), type_info));
|
||||
SaveTypeInfoOrtFormat(builder, value_info_proto.type(), type_info));
|
||||
} else {
|
||||
// we have a NodeArg for missing optional values (empty name, no type) so allow for that.
|
||||
// everything else should have type info
|
||||
|
|
@ -126,7 +167,7 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
string_data = builder.CreateVectorOfStrings(string_data_vec);
|
||||
} else {
|
||||
std::unique_ptr<uint8_t[]> unpacked_tensor;
|
||||
size_t tensor_byte_size;
|
||||
size_t tensor_byte_size = 0;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor, tensor_byte_size));
|
||||
raw_data = builder.CreateVector(unpacked_tensor.get(), tensor_byte_size);
|
||||
|
|
@ -228,6 +269,9 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
#undef GET_FBS_ATTR
|
||||
#undef GET_DATA_VEC
|
||||
|
||||
static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info,
|
||||
TypeProto& type_proto) ORT_MUST_USE_RESULT;
|
||||
|
||||
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor,
|
||||
TensorProto& initializer) {
|
||||
initializer.Clear();
|
||||
|
|
@ -301,6 +345,23 @@ static Status LoadTensorTypeAndShapeOrtFormat(const fbs::TensorTypeAndShape& fbs
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status LoadSequenceTypeOrtFormat(const fbs::SequenceType& fbs_sequence_type,
|
||||
TypeProto_Sequence& sequence_type_proto) {
|
||||
auto fbs_type_info = fbs_sequence_type.elem_type();
|
||||
ORT_RETURN_IF(nullptr == fbs_type_info, "Null value type info in fbs::SequenceType. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadTypeInfoOrtFormat(*fbs_type_info, *sequence_type_proto.mutable_elem_type()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status LoadMapTypeOrtFormat(const fbs::MapType& fbs_map_type,
|
||||
TypeProto_Map& map_type_proto) {
|
||||
map_type_proto.set_key_type(fbs_map_type.key_type());
|
||||
auto fbs_type_info = fbs_map_type.value_type();
|
||||
ORT_RETURN_IF(nullptr == fbs_type_info, "Null value type info in fbs::MapType. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadTypeInfoOrtFormat(*fbs_type_info, *map_type_proto.mutable_value_type()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info,
|
||||
TypeProto& type_proto) {
|
||||
LoadStringFromOrtFormat(*type_proto.mutable_denotation(), fbs_type_info.denotation());
|
||||
|
|
@ -309,8 +370,16 @@ static Status LoadTypeInfoOrtFormat(const fbs::TypeInfo& fbs_type_info,
|
|||
auto fbs_tensor_type = fbs_type_info.value_as_tensor_type();
|
||||
ORT_RETURN_IF(nullptr == fbs_tensor_type, "Null tensor type info. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadTensorTypeAndShapeOrtFormat(*fbs_tensor_type, *type_proto.mutable_tensor_type()));
|
||||
} else if (value_type == fbs::TypeInfoValue_sequence_type) {
|
||||
auto fbs_sequence_type = fbs_type_info.value_as_sequence_type();
|
||||
ORT_RETURN_IF(nullptr == fbs_sequence_type, "Null sequence type info. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadSequenceTypeOrtFormat(*fbs_sequence_type, *type_proto.mutable_sequence_type()));
|
||||
} else if (value_type == fbs::TypeInfoValue_map_type) {
|
||||
auto fbs_map_type = fbs_type_info.value_as_map_type();
|
||||
ORT_RETURN_IF(nullptr == fbs_map_type, "Null map type info. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadMapTypeOrtFormat(*fbs_map_type, *type_proto.mutable_map_type()));
|
||||
} else {
|
||||
// TODO: This may be required for traditional ML models.
|
||||
// We do not support SparseTensor and Opaque for now
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Type:", value_type, " is not supported currently");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
|
@ -23,17 +24,41 @@ class InferenceSessionGetGraphWrapper : public InferenceSession {
|
|||
const Environment& env) : InferenceSession(session_options, env) {
|
||||
}
|
||||
|
||||
const Graph& GetGraph() {
|
||||
const Graph& GetGraph() const {
|
||||
return model_->MainGraph();
|
||||
}
|
||||
|
||||
const SessionState& GetSessionState() {
|
||||
const SessionState& GetSessionState() const {
|
||||
return InferenceSession::GetSessionState();
|
||||
}
|
||||
};
|
||||
|
||||
namespace test {
|
||||
|
||||
struct OrtModelTestInfo {
|
||||
std::basic_string<ORTCHAR_T> model_filename;
|
||||
std::string logid;
|
||||
NameMLValMap inputs;
|
||||
std::vector<std::string> output_names;
|
||||
std::function<void(const std::vector<OrtValue>&)> output_verifier;
|
||||
std::vector<std::pair<std::string, std::string>> configs;
|
||||
};
|
||||
|
||||
void RunOrtModel(const OrtModelTestInfo& test_info) {
|
||||
SessionOptions so;
|
||||
so.session_logid = test_info.logid;
|
||||
for (const auto& config : test_info.configs)
|
||||
so.AddConfigEntry(config.first.c_str(), config.second.c_str());
|
||||
|
||||
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
std::vector<OrtValue> fetches;
|
||||
ASSERT_STATUS_OK(session_object.Run(test_info.inputs, test_info.output_names, &fetches));
|
||||
test_info.output_verifier(fetches);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Same Tensor from ONNX and ORT format will have different binary representation, need to compare value by value
|
||||
static void CompareTensors(const OrtValue& left_value, const OrtValue& right_value) {
|
||||
|
|
@ -56,89 +81,60 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val
|
|||
}
|
||||
}
|
||||
|
||||
static void CompareTypeProtos(const TypeProto& left_type_proto, const TypeProto& right_type_proto) {
|
||||
ASSERT_EQ(left_type_proto.denotation(), right_type_proto.denotation());
|
||||
|
||||
ASSERT_EQ(left_type_proto.has_tensor_type(), right_type_proto.has_tensor_type());
|
||||
ASSERT_EQ(left_type_proto.has_sequence_type(), right_type_proto.has_sequence_type());
|
||||
ASSERT_EQ(left_type_proto.has_map_type(), right_type_proto.has_map_type());
|
||||
|
||||
if (left_type_proto.has_tensor_type()) {
|
||||
const auto& left_tensor_type = left_type_proto.tensor_type();
|
||||
const auto& right_tensor_type = right_type_proto.tensor_type();
|
||||
|
||||
ASSERT_EQ(left_tensor_type.elem_type(), right_tensor_type.elem_type());
|
||||
|
||||
const auto& left_shape = left_tensor_type.shape();
|
||||
const auto& right_shape = right_tensor_type.shape();
|
||||
|
||||
ASSERT_EQ(left_shape.dim_size(), right_shape.dim_size());
|
||||
for (int i = 0; i < left_shape.dim_size(); i++) {
|
||||
const auto& left_dim = left_shape.dim(i);
|
||||
const auto& right_dim = right_shape.dim(i);
|
||||
ASSERT_EQ(left_dim.has_dim_value(), right_dim.has_dim_value());
|
||||
ASSERT_EQ(left_dim.dim_value(), right_dim.dim_value());
|
||||
ASSERT_EQ(left_dim.has_dim_param(), right_dim.has_dim_param());
|
||||
ASSERT_EQ(left_dim.dim_param(), right_dim.dim_param());
|
||||
}
|
||||
} else if (left_type_proto.has_sequence_type()) {
|
||||
CompareTypeProtos(left_type_proto.sequence_type().elem_type(), right_type_proto.sequence_type().elem_type());
|
||||
} else if (left_type_proto.has_map_type()) {
|
||||
const auto& left_map = left_type_proto.map_type();
|
||||
const auto& right_map = right_type_proto.map_type();
|
||||
ASSERT_EQ(left_map.key_type(), right_map.key_type());
|
||||
CompareTypeProtos(left_map.value_type(), right_map.value_type());
|
||||
} else {
|
||||
FAIL(); // We do not support SparseTensor and Opaque for now
|
||||
}
|
||||
}
|
||||
|
||||
static void CompareValueInfos(const ValueInfoProto& left, const ValueInfoProto& right) {
|
||||
ASSERT_EQ(left.name(), right.name());
|
||||
ASSERT_EQ(left.doc_string(), right.doc_string());
|
||||
|
||||
std::string left_data;
|
||||
std::string right_data;
|
||||
|
||||
const auto& left_type_proto = left.type();
|
||||
const auto& right_type_proto = right.type();
|
||||
|
||||
ASSERT_EQ(left_type_proto.denotation(), right_type_proto.denotation());
|
||||
ASSERT_TRUE(left_type_proto.has_tensor_type());
|
||||
ASSERT_TRUE(right_type_proto.has_tensor_type());
|
||||
|
||||
const auto& left_tensor_type = left_type_proto.tensor_type();
|
||||
const auto& right_tensor_type = right_type_proto.tensor_type();
|
||||
|
||||
ASSERT_EQ(left_tensor_type.elem_type(), right_tensor_type.elem_type());
|
||||
|
||||
const auto& left_shape = left_tensor_type.shape();
|
||||
const auto& right_shape = right_tensor_type.shape();
|
||||
|
||||
ASSERT_EQ(left_shape.dim_size(), right_shape.dim_size());
|
||||
for (int i = 0; i < left_shape.dim_size(); i++) {
|
||||
const auto& left_dim = left_shape.dim(i);
|
||||
const auto& right_dim = right_shape.dim(i);
|
||||
ASSERT_EQ(left_dim.has_dim_value(), right_dim.has_dim_value());
|
||||
ASSERT_EQ(left_dim.dim_value(), right_dim.dim_value());
|
||||
ASSERT_EQ(left_dim.has_dim_param(), right_dim.has_dim_param());
|
||||
ASSERT_EQ(left_dim.dim_param(), right_dim.dim_param());
|
||||
}
|
||||
CompareTypeProtos(left.type(), right.type());
|
||||
}
|
||||
|
||||
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
||||
const auto output_file = ORT_TSTR("ort_github_issue_4031.onnx.ort");
|
||||
SessionOptions so;
|
||||
so.session_logid = "SerializeToOrtFormat";
|
||||
so.optimized_model_filepath = output_file;
|
||||
// not strictly necessary - type should be inferred from the filename
|
||||
so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT, "ORT");
|
||||
void CompareGraphAndSessionState(const InferenceSessionGetGraphWrapper& session_object_1,
|
||||
const InferenceSessionGetGraphWrapper& session_object_2) {
|
||||
const auto& graph_1 = session_object_1.GetGraph();
|
||||
const auto& graph_2 = session_object_2.GetGraph();
|
||||
|
||||
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
|
||||
const auto& session_state_1 = session_object_1.GetSessionState();
|
||||
const auto& session_state_2 = session_object_2.GetSessionState();
|
||||
|
||||
// create .ort file during Initialize due to values in SessionOptions
|
||||
ASSERT_STATUS_OK(session_object.Load(ORT_TSTR("testdata/ort_github_issue_4031.onnx")));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
// create inputs
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
|
||||
&ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("state_var_in", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names{"state_var_out"};
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
|
||||
|
||||
SessionOptions so2;
|
||||
so.session_logid = "LoadOrtFormat";
|
||||
// not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we
|
||||
// think we're testing set it.
|
||||
so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT");
|
||||
|
||||
// load serialized version
|
||||
InferenceSessionGetGraphWrapper session_object2{so2, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object2.Load(output_file));
|
||||
ASSERT_STATUS_OK(session_object2.Initialize());
|
||||
|
||||
// compare contents on Graph instances
|
||||
const auto& graph = session_object.GetGraph();
|
||||
const auto& graph2 = session_object2.GetGraph();
|
||||
|
||||
const auto& session_state = session_object.GetSessionState();
|
||||
const auto& session_state2 = session_object2.GetSessionState();
|
||||
|
||||
const auto& name_idx_map = session_state.GetOrtValueNameIdxMap();
|
||||
const auto& name_idx_map2 = session_state2.GetOrtValueNameIdxMap();
|
||||
|
||||
const auto& i1 = session_state.GetInitializedTensors();
|
||||
const auto& i2 = session_state2.GetInitializedTensors();
|
||||
const auto& i1 = session_state_1.GetInitializedTensors();
|
||||
const auto& i2 = session_state_2.GetInitializedTensors();
|
||||
ASSERT_EQ(i1.size(), i2.size());
|
||||
|
||||
for (const auto& pair : i1) {
|
||||
|
|
@ -148,24 +144,12 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
|||
const OrtValue& left = pair.second;
|
||||
const OrtValue& right = iter->second;
|
||||
CompareTensors(left, right);
|
||||
|
||||
// check NodeArgs for both initializers. need to get name from map as we store the initialized tensors against
|
||||
// their OrtValueIdx in SessionState.
|
||||
std::string name;
|
||||
std::string name2;
|
||||
ASSERT_STATUS_OK(name_idx_map.GetName(pair.first, name));
|
||||
ASSERT_STATUS_OK(name_idx_map2.GetName(pair.first, name2));
|
||||
ASSERT_EQ(name, name2);
|
||||
|
||||
const auto& left_nodearg = *graph.GetNodeArg(name);
|
||||
const auto& right_nodearg = *graph2.GetNodeArg(name2);
|
||||
CompareValueInfos(left_nodearg.ToProto(), right_nodearg.ToProto());
|
||||
}
|
||||
|
||||
// check all node args are fine
|
||||
for (const auto& input : graph.GetInputs()) {
|
||||
const auto& left = *graph.GetNodeArg(input->Name());
|
||||
const auto* right = graph2.GetNodeArg(input->Name());
|
||||
for (const auto& input : graph_1.GetInputsIncludingInitializers()) {
|
||||
const auto& left = *graph_1.GetNodeArg(input->Name());
|
||||
const auto* right = graph_2.GetNodeArg(input->Name());
|
||||
ASSERT_TRUE(right != nullptr);
|
||||
|
||||
const auto& left_proto = left.ToProto();
|
||||
|
|
@ -173,8 +157,8 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
|||
CompareValueInfos(left_proto, right_proto);
|
||||
}
|
||||
|
||||
for (const auto& left : graph.Nodes()) {
|
||||
const auto* right = graph2.GetNode(left.Index());
|
||||
for (const auto& left : graph_1.Nodes()) {
|
||||
const auto* right = graph_2.GetNode(left.Index());
|
||||
ASSERT_TRUE(right != nullptr);
|
||||
const auto& left_outputs = left.OutputDefs();
|
||||
const auto& right_outputs = right->OutputDefs();
|
||||
|
|
@ -192,47 +176,165 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check results match
|
||||
std::vector<OrtValue> fetches2;
|
||||
ASSERT_STATUS_OK(session_object2.Run(feeds, output_names, &fetches2));
|
||||
|
||||
const auto& output = fetches[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output.Shape().Size() == 1);
|
||||
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
|
||||
|
||||
const auto& output2 = fetches2[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output2.Shape().Size() == 1);
|
||||
ASSERT_TRUE(output2.Data<float>()[0] == 125.f);
|
||||
}
|
||||
#endif
|
||||
|
||||
// test that we can deserialize and run a previously saved ORT format model
|
||||
TEST(OrtModelOnlyTests, LoadOrtFormatModel) {
|
||||
const auto model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
|
||||
void SaveAndCompareModels(const std::string& onnx_file, const std::basic_string<ORTCHAR_T>& ort_file) {
|
||||
SessionOptions so;
|
||||
so.session_logid = "LoadOrtFormatModel";
|
||||
|
||||
so.session_logid = "SerializeToOrtFormat";
|
||||
so.optimized_model_filepath = ort_file;
|
||||
// not strictly necessary - type should be inferred from the filename
|
||||
so.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT, "ORT");
|
||||
InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_filename)); // infer type from filename
|
||||
|
||||
// create .ort file during Initialize due to values in SessionOptions
|
||||
ASSERT_STATUS_OK(session_object.Load(onnx_file));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
SessionOptions so2;
|
||||
so2.session_logid = "LoadOrtFormat";
|
||||
// not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we
|
||||
// think we're testing set it.
|
||||
so2.AddConfigEntry(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT");
|
||||
|
||||
// load serialized version
|
||||
InferenceSessionGetGraphWrapper session_object2{so2, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object2.Load(ort_file));
|
||||
ASSERT_STATUS_OK(session_object2.Initialize());
|
||||
|
||||
CompareGraphAndSessionState(session_object, session_object2);
|
||||
}
|
||||
|
||||
TEST(OrtModelOnlyTests, SerializeToOrtFormat) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort");
|
||||
SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file);
|
||||
|
||||
OrtModelTestInfo test_info;
|
||||
test_info.model_filename = ort_file;
|
||||
test_info.logid = "SerializeToOrtFormat";
|
||||
test_info.configs.push_back(std::make_pair(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT"));
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
|
||||
&ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("state_var_in", ml_value));
|
||||
test_info.inputs.insert(std::make_pair("state_var_in", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names{"state_var_out"};
|
||||
std::vector<OrtValue> fetches;
|
||||
test_info.output_names = {"state_var_out"};
|
||||
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
|
||||
const auto& output = fetches[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output.Shape().Size() == 1);
|
||||
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
|
||||
};
|
||||
|
||||
ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
|
||||
|
||||
const auto& output = fetches[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output.Shape().Size() == 1);
|
||||
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) {
|
||||
const std::basic_string<ORTCHAR_T> ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort");
|
||||
SaveAndCompareModels("testdata/sklearn_bin_voting_classifier_soft.onnx", ort_file);
|
||||
|
||||
OrtModelTestInfo test_info;
|
||||
test_info.model_filename = ort_file;
|
||||
test_info.logid = "SerializeToOrtFormatMLOps";
|
||||
test_info.configs.push_back(std::make_pair(ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT, "ORT"));
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {3, 2},
|
||||
{0.f, 1.f, 1.f, 1.f, 2.f, 0.f}, &ml_value);
|
||||
test_info.inputs.insert(std::make_pair("input", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
test_info.output_names = {"output_label", "output_probability"};
|
||||
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
|
||||
const auto& output_0 = fetches[0].Get<Tensor>();
|
||||
int64_t tensor_size = 3;
|
||||
ASSERT_EQ(tensor_size, output_0.Shape().Size());
|
||||
const auto& output_0_data = output_0.Data<std::string>();
|
||||
for (int64_t i = 0; i < tensor_size; i++)
|
||||
ASSERT_TRUE(output_0_data[i] == "A");
|
||||
|
||||
VectorMapStringToFloat expected_output_1 = {{{"A", 0.572734f}, {"B", 0.427266f}},
|
||||
{{"A", 0.596016f}, {"B", 0.403984f}},
|
||||
{{"A", 0.656315f}, {"B", 0.343685f}}};
|
||||
const auto& actual_output_1 = fetches[1].Get<VectorMapStringToFloat>();
|
||||
ASSERT_EQ(actual_output_1.size(), 3);
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
const auto& expected = expected_output_1[i];
|
||||
const auto& actual = actual_output_1[i];
|
||||
ASSERT_EQ(actual.size(), 2);
|
||||
ASSERT_NEAR(expected.at("A"), actual.at("A"), 1e-6);
|
||||
ASSERT_NEAR(expected.at("B"), actual.at("B"), 1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
#endif // #if !defined(DISABLE_ML_OPS)
|
||||
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
// test that we can deserialize and run a previously saved ORT format model
|
||||
TEST(OrtModelOnlyTests, LoadOrtFormatModel) {
|
||||
OrtModelTestInfo test_info;
|
||||
test_info.model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort");
|
||||
test_info.logid = "LoadOrtFormatModel";
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
|
||||
&ml_value);
|
||||
test_info.inputs.insert(std::make_pair("state_var_in", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
test_info.output_names = {"state_var_out"};
|
||||
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
|
||||
const auto& output = fetches[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output.Shape().Size() == 1);
|
||||
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
|
||||
};
|
||||
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
// test that we can deserialize and run a previously saved ORT format model
|
||||
// for a model with sequence and map outputs
|
||||
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOps) {
|
||||
OrtModelTestInfo test_info;
|
||||
test_info.model_filename = ORT_TSTR("testdata/sklearn_bin_voting_classifier_soft.ort");
|
||||
test_info.logid = "LoadOrtFormatModelMLOps";
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {3, 2},
|
||||
{0.f, 1.f, 1.f, 1.f, 2.f, 0.f}, &ml_value);
|
||||
test_info.inputs.insert(std::make_pair("input", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
test_info.output_names = {"output_label", "output_probability"};
|
||||
test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
|
||||
const auto& output_0 = fetches[0].Get<Tensor>();
|
||||
int64_t tensor_size = 3;
|
||||
ASSERT_EQ(tensor_size, output_0.Shape().Size());
|
||||
const auto& output_0_data = output_0.Data<std::string>();
|
||||
for (int64_t i = 0; i < tensor_size; i++)
|
||||
ASSERT_TRUE(output_0_data[i] == "A");
|
||||
|
||||
VectorMapStringToFloat expected_output_1 = {{{"A", 0.572734f}, {"B", 0.427266f}},
|
||||
{{"A", 0.596016f}, {"B", 0.403984f}},
|
||||
{{"A", 0.656315f}, {"B", 0.343685f}}};
|
||||
const auto& actual_output_1 = fetches[1].Get<VectorMapStringToFloat>();
|
||||
ASSERT_EQ(actual_output_1.size(), 3);
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
const auto& expected = expected_output_1[i];
|
||||
const auto& actual = actual_output_1[i];
|
||||
ASSERT_EQ(actual.size(), 2);
|
||||
ASSERT_NEAR(expected.at("A"), actual.at("A"), 1e-6);
|
||||
ASSERT_NEAR(expected.at("B"), actual.at("B"), 1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort
vendored
Normal file
BIN
onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.ort
vendored
Normal file
Binary file not shown.
1
onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.readme.txt
vendored
Normal file
1
onnxruntime/test/testdata/sklearn_bin_voting_classifier_soft.readme.txt
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
The model sklearn_bin_voting_classifier_soft.onnx is from the scikit-learn onnx converter tests
|
||||
Loading…
Reference in a new issue