Exclude the Map types from the build if ML ops are disabled. (#4908)

* Exclude the Map types from the build if ML ops are disabled. They're the only ops that use Map.
This commit is contained in:
Scott McKay 2020-08-27 17:48:12 +10:00 committed by GitHub
parent 792ed44537
commit 08eb15068c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 163 additions and 26 deletions

View file

@ -24,7 +24,8 @@ class TypeProto;
namespace onnxruntime {
/// Predefined registered types
//maps
#if !defined(DISABLE_ML_OPS)
//maps (only used by ML ops)
using MapStringToString = std::map<std::string, std::string>;
using MapStringToInt64 = std::map<std::string, int64_t>;
using MapStringToFloat = std::map<std::string, float>;
@ -33,10 +34,13 @@ using MapInt64ToString = std::map<int64_t, std::string>;
using MapInt64ToInt64 = std::map<int64_t, int64_t>;
using MapInt64ToFloat = std::map<int64_t, float>;
using MapInt64ToDouble = std::map<int64_t, double>;
#endif
//vectors/sequences
#if !defined(DISABLE_ML_OPS)
using VectorMapStringToFloat = std::vector<MapStringToFloat>;
using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
#endif
using VectorString = std::vector<std::string>;
using VectorInt64 = std::vector<int64_t>;
@ -422,6 +426,7 @@ struct GetMLDataType<T, false> {
}
};
#if !defined(DISABLE_ML_OPS)
/// MapTypes helper API
/// K should always be one of the primitive data types
/// V can be either a primitive type (in which case it is a tensor)
@ -445,6 +450,7 @@ struct SetMapTypes {
CopyMutableMapValue(*value_proto, proto);
}
};
#endif
/// Sequence helpers
///
@ -713,6 +719,7 @@ class NonTensorType : public NonTensorTypeBase {
NonTensorType() = default;
};
#if !defined(DISABLE_ML_OPS)
/**
* \brief MapType. Use this type to register
* mapping types.
@ -741,6 +748,7 @@ class MapType : public NonTensorType<CPPType> {
SetMapTypes<typename CPPType::key_type, typename CPPType::mapped_type>::Set(this->mutable_type_proto());
}
};
#endif
/**
* \brief SequenceType. Use to register sequence for non-tensor types.
@ -968,6 +976,7 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
return SparseTensorType<ELEM_TYPE>::Type(); \
}
#if !defined(DISABLE_ML_OPS)
#define ORT_REGISTER_MAP(TYPE) \
template <> \
MLDataType MapType<TYPE>::Type() { \
@ -978,6 +987,7 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
MLDataType DataTypeImpl::GetType<TYPE>() { \
return MapType<TYPE>::Type(); \
}
#endif
#define ORT_REGISTER_SEQ(TYPE) \
template <> \

View file

@ -60,9 +60,13 @@ struct TensorElementTypeSetter<T> {
static void SetSparseTensorElementType(ONNX_NAMESPACE::TypeProto& proto) {
proto.mutable_sparse_tensor_type()->set_elem_type(utils::ToTensorProtoElementType<T>());
}
#if !defined(DISABLE_ML_OPS)
static void SetMapKeyType(ONNX_NAMESPACE::TypeProto& proto) {
proto.mutable_map_type()->set_key_type(utils::ToTensorProtoElementType<T>());
}
#endif
constexpr static int32_t GetElementType() {
return utils::ToTensorProtoElementType<T>();
}
@ -98,10 +102,12 @@ template struct
template struct
TensorElementTypeSetter<BFloat16>;
#if !defined(DISABLE_ML_OPS)
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto& value_proto,
ONNX_NAMESPACE::TypeProto& map_proto) {
map_proto.mutable_map_type()->mutable_value_type()->CopyFrom(value_proto);
}
#endif
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto& elem_proto,
ONNX_NAMESPACE::TypeProto& proto) {
@ -121,8 +127,10 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto,
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_SparseTensor& tensor_proto,
const ONNX_NAMESPACE::TypeProto_SparseTensor& type_proto);
#if !defined(DISABLE_ML_OPS)
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto,
const ONNX_NAMESPACE::TypeProto_Map& type_proto);
#endif
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto,
const ONNX_NAMESPACE::TypeProto_Sequence& type_proto);
@ -139,6 +147,7 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto,
*/
}
#if !defined(DISABLE_ML_OPS)
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto,
const ONNX_NAMESPACE::TypeProto_Map& type_proto) {
const auto& lhs = map_proto;
@ -171,6 +180,7 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto,
}
return result;
}
#endif
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto,
const ONNX_NAMESPACE::TypeProto_Sequence& type_proto) {
@ -185,9 +195,11 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto,
case TypeProto::ValueCase::kSequenceType:
result = IsCompatible(lhs.elem_type().sequence_type(), rhs.elem_type().sequence_type());
break;
#if !defined(DISABLE_ML_OPS)
case TypeProto::ValueCase::kMapType:
result = IsCompatible(lhs.elem_type().map_type(), rhs.elem_type().map_type());
break;
#endif
case TypeProto::ValueCase::kOpaqueType:
result = IsCompatible(lhs.elem_type().opaque_type(), rhs.elem_type().opaque_type());
break;
@ -203,7 +215,8 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto,
}
return result;
}
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto, const ONNX_NAMESPACE::TypeProto_Opaque& type_proto) {
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto,
const ONNX_NAMESPACE::TypeProto_Opaque& type_proto) {
const auto& lhs = opaque_proto;
const auto& rhs = type_proto;
bool lhs_domain = utils::HasDomain(lhs);
@ -454,6 +467,7 @@ const ONNX_NAMESPACE::TypeProto* NonTensorTypeBase::GetTypeProto() const {
return impl_->GetProto();
}
#if !defined(DISABLE_ML_OPS)
bool NonTensorTypeBase::IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const {
const auto* thisProto = impl_->GetProto();
if (&type_proto == thisProto) {
@ -467,6 +481,7 @@ bool NonTensorTypeBase::IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_pr
ORT_ENFORCE(utils::HasKeyType(thisProto->map_type()));
return data_types_internal::IsCompatible(thisProto->map_type(), type_proto.map_type());
}
#endif
bool NonTensorTypeBase::IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const {
const auto* thisProto = impl_->GetProto();
@ -532,6 +547,7 @@ ORT_REGISTER_SPARSE_TENSOR_TYPE(uint64_t);
ORT_REGISTER_SPARSE_TENSOR_TYPE(MLFloat16);
ORT_REGISTER_SPARSE_TENSOR_TYPE(BFloat16);
#if !defined(DISABLE_ML_OPS)
ORT_REGISTER_MAP(MapStringToString);
ORT_REGISTER_MAP(MapStringToInt64);
ORT_REGISTER_MAP(MapStringToFloat);
@ -540,6 +556,7 @@ ORT_REGISTER_MAP(MapInt64ToString);
ORT_REGISTER_MAP(MapInt64ToInt64);
ORT_REGISTER_MAP(MapInt64ToFloat);
ORT_REGISTER_MAP(MapInt64ToDouble);
#endif
ORT_REGISTER_SEQ_TENSOR_TYPE(float);
ORT_REGISTER_SEQ_TENSOR_TYPE(double);
@ -556,8 +573,10 @@ ORT_REGISTER_SEQ_TENSOR_TYPE(std::string);
ORT_REGISTER_SEQ_TENSOR_TYPE(MLFloat16);
ORT_REGISTER_SEQ_TENSOR_TYPE(BFloat16);
#if !defined(DISABLE_ML_OPS)
ORT_REGISTER_SEQ(VectorMapStringToFloat);
ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
#endif
// Used for Tensor Proto registrations
#define REGISTER_TENSOR_PROTO(TYPE, reg_fn) \
@ -617,6 +636,7 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
REGISTER_SPARSE_TENSOR_PROTO(MLFloat16, reg_fn);
REGISTER_SPARSE_TENSOR_PROTO(BFloat16, reg_fn);
#if !defined(DISABLE_ML_OPS)
REGISTER_ONNX_PROTO(MapStringToString, reg_fn);
REGISTER_ONNX_PROTO(MapStringToInt64, reg_fn);
REGISTER_ONNX_PROTO(MapStringToFloat, reg_fn);
@ -625,6 +645,7 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
REGISTER_ONNX_PROTO(MapInt64ToInt64, reg_fn);
REGISTER_ONNX_PROTO(MapInt64ToFloat, reg_fn);
REGISTER_ONNX_PROTO(MapInt64ToDouble, reg_fn);
#endif
REGISTER_SEQ_TENSOR_PROTO(int32_t, reg_fn);
REGISTER_SEQ_TENSOR_PROTO(float, reg_fn);
@ -641,8 +662,10 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
REGISTER_SEQ_TENSOR_PROTO(MLFloat16, reg_fn);
REGISTER_SEQ_TENSOR_PROTO(BFloat16, reg_fn);
#if !defined(DISABLE_ML_OPS)
REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn);
REGISTER_ONNX_PROTO(VectorMapInt64ToFloat, reg_fn);
#endif
}
} // namespace data_types_internal
@ -982,18 +1005,18 @@ ContainerChecker::ContainerChecker(MLDataType ml_type) {
types_.emplace_back(ContainerType::kTensor, type_proto->tensor_type().elem_type());
type_proto = nullptr;
break;
case TypeProto::ValueCase::kMapType:
{
#if !defined(DISABLE_ML_OPS)
case TypeProto::ValueCase::kMapType: {
const auto& map_type = type_proto->map_type();
types_.emplace_back(ContainerType::kMap, map_type.key_type());
// Move on handling the value
type_proto = &map_type.value_type();
}
break;
} break;
#endif
case TypeProto::ValueCase::kSequenceType:
types_.emplace_back(ContainerType::kSequence, TensorProto_DataType_UNDEFINED);
type_proto = &type_proto->sequence_type().elem_type();
break;
types_.emplace_back(ContainerType::kSequence, TensorProto_DataType_UNDEFINED);
type_proto = &type_proto->sequence_type().elem_type();
break;
case TypeProto::ValueCase::kOpaqueType:
// We do not handle this and terminate here
types_.emplace_back(ContainerType::kOpaque,

View file

@ -1084,8 +1084,6 @@ ORT_API_STATUS_IMPL(OrtApis::AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Ou
API_IMPL_END
}
static const int NUM_MAP_INDICES = 2;
template <typename T>
ORT_STATUS_PTR OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) {
auto& data = p_ml_value->Get<T>();
@ -1100,13 +1098,21 @@ ORT_STATUS_PTR OrtGetNumSequenceElements<TensorSeq>(const OrtValue* p_ml_value,
return nullptr;
}
#if !defined(DISABLE_ML_OPS)
static const int NUM_MAP_INDICES = 2;
#endif
static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) {
ONNXType value_type;
if (auto status = OrtApis::GetValueType(value, &value_type))
return status;
if (value_type == ONNX_TYPE_MAP) {
#if !defined(DISABLE_ML_OPS)
*out = NUM_MAP_INDICES;
return nullptr;
#else
return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
if (value_type == ONNX_TYPE_SEQUENCE) {
auto v = reinterpret_cast<const OrtValue*>(value);
@ -1115,6 +1121,7 @@ static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) {
if (type->IsTensorSequenceType()) {
return OrtGetNumSequenceElements<TensorSeq>(v, out);
} else {
#if !defined(DISABLE_ML_OPS)
utils::ContainerChecker c_checker(type);
if (c_checker.IsSequenceOf<std::map<std::string, float>>()) {
return OrtGetNumSequenceElements<VectorMapStringToFloat>(v, out);
@ -1123,6 +1130,9 @@ static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) {
} else {
return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported sequence types.");
}
#else
return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
} else {
return OrtApis::CreateStatus(ORT_FAIL, "Input is not of type sequence or map.");
@ -1135,6 +1145,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueCount, _In_ const OrtValue* value, _Out_ si
API_IMPL_END
}
#if !defined(DISABLE_ML_OPS)
///////////////////
// OrtGetValueImplSeqOfMap
template <typename T>
@ -1153,6 +1164,7 @@ static ORT_STATUS_PTR OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int in
*out = value.release();
return nullptr;
}
#endif
ORT_STATUS_PTR PopulateTensorWithData(_Inout_ OrtValue* oval, _In_ const void* data_elem, size_t num_elems,
size_t elem_size) {
@ -1231,6 +1243,7 @@ static ORT_STATUS_PTR OrtGetValueImplSeq(_In_ const OrtValue* value, int index,
if (type->IsTensorSequenceType()) {
return OrtGetValueImplSeqOfTensors(p_ml_value, index, allocator, out);
} else {
#if !defined(DISABLE_ML_OPS)
utils::ContainerChecker c_checker(type);
if (c_checker.IsSequenceOf<std::map<std::string, float>>()) {
return OrtGetValueImplSeqOfMap<VectorMapStringToFloat>(p_ml_value, index, out);
@ -1239,9 +1252,13 @@ static ORT_STATUS_PTR OrtGetValueImplSeq(_In_ const OrtValue* value, int index,
} else {
return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported sequence types.");
}
#else
return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
}
#if !defined(DISABLE_ML_OPS)
template <typename T>
static ORT_STATUS_PTR OrtGetValueImplMapHelper(_In_ const OrtValue* p_ml_value, int index,
_Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) {
@ -1308,6 +1325,7 @@ static ORT_STATUS_PTR OrtGetValueImplMap(_In_ const OrtValue* value, int index,
}
return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types.");
}
#endif
static ORT_STATUS_PTR OrtGetValueImpl(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** out) {
@ -1315,7 +1333,11 @@ static ORT_STATUS_PTR OrtGetValueImpl(_In_ const OrtValue* value, int index, _In
if (auto status = OrtApis::GetValueType(value, &value_type))
return status;
if (value_type == ONNX_TYPE_MAP) {
#if !defined(DISABLE_ML_OPS)
return OrtGetValueImplMap(value, index, allocator, out);
#else
return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
if (value_type == ONNX_TYPE_SEQUENCE) {
return OrtGetValueImplSeq(value, index, allocator, out);
@ -1333,6 +1355,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetValue, _In_ const OrtValue* value, int index, _I
///////////////////
// OrtCreateValue
#if !defined(DISABLE_ML_OPS)
template <typename T>
static OrtStatus* OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size_t num_values,
_Outptr_ OrtValue** out) {
@ -1352,6 +1376,7 @@ static OrtStatus* OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size
*out = value.release();
return nullptr;
}
#endif
template <typename TensorElemType>
static OrtStatus* OrtCreateValueImplSeqHelperTensor(const Tensor& tensor,
@ -1461,6 +1486,7 @@ static ORT_STATUS_PTR OrtCreateValueImplSeq(_In_reads_(num_values) const OrtValu
if (first_value_type == ONNX_TYPE_TENSOR) {
return OrtCreateValueImplSeqHelper(in, num_values, out);
} else if (first_value_type == ONNX_TYPE_MAP) {
#if !defined(DISABLE_ML_OPS)
auto map_type = first_mlvalue->Type();
utils::ContainerChecker c_checker(map_type);
if (c_checker.IsMapOf<std::string, float>()) {
@ -1471,11 +1497,17 @@ static ORT_STATUS_PTR OrtCreateValueImplSeq(_In_reads_(num_values) const OrtValu
} else {
return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types.");
}
#else
ORT_UNUSED_PARAMETER(first_mlvalue);
return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
} else {
return OrtApis::CreateStatus(ORT_FAIL, "Unsupported input type");
}
}
#if !defined(DISABLE_ML_OPS)
template <typename KeyType, typename ValueType>
static OrtStatus* OrtCreateMapMLValue(const Tensor& key_tensor, const Tensor& value_tensor, _Outptr_ OrtValue** out) {
using MapType = std::map<KeyType, ValueType>;
@ -1559,6 +1591,7 @@ static ORT_STATUS_PTR OrtCreateValueImplMap(const OrtValue* const* in, size_t nu
}
return OrtApis::CreateStatus(ORT_FAIL, "Key type is not supported yet.");
}
#endif
static ORT_STATUS_PTR OrtCreateValueImpl(_In_reads_(num_values) const OrtValue* const* in, size_t num_values,
enum ONNXType value_type, _Outptr_ OrtValue** out) {
@ -1566,7 +1599,11 @@ static ORT_STATUS_PTR OrtCreateValueImpl(_In_reads_(num_values) const OrtValue*
return OrtApis::CreateStatus(ORT_FAIL, "Number of values should be at least 1.");
}
if (value_type == ONNX_TYPE_MAP) {
#if !defined(DISABLE_ML_OPS)
return OrtCreateValueImplMap(in, num_values, out);
#else
return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
if (value_type == ONNX_TYPE_SEQUENCE) {
return OrtCreateValueImplSeq(in, num_values, out);

View file

@ -387,6 +387,7 @@ std::string _get_type_name(std::string&) {
return std::string("string");
}
#if !defined(DISABLE_ML_OPS)
template <typename KeyType, typename ValueType, typename KeyGetterType, typename ValueGetterType>
void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string& name_input, PyObject*& value,
PyObject* item, std::map<KeyType, ValueType>& current,
@ -549,6 +550,7 @@ void CreateMapMLValue_AgnosticVectorMap(PyObject* iterator, PyObject* item, Allo
throw std::runtime_error("Size of dictionary is empty, unable to run the prediction.");
}
}
#endif
void CreateGenericIterableMLValue(PyObject* iterator, AllocatorPtr alloc, const std::string& name_input,
OrtValue* p_mlvalue) {
@ -573,7 +575,13 @@ void CreateGenericIterableMLValue(PyObject* iterator, AllocatorPtr alloc, const
throw std::runtime_error("Input must be a list of dictionaries or a single numpy array for input '" +
name_input + std::string("'."));
}
#if !defined(DISABLE_ML_OPS)
CreateMapMLValue_AgnosticVectorMap(iterator, item, alloc, name_input, p_mlvalue);
#else
ORT_UNUSED_PARAMETER(alloc);
ORT_UNUSED_PARAMETER(p_mlvalue);
throw std::runtime_error("Map type is not supported in this build.");
#endif
}
}
@ -610,7 +618,13 @@ void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const
auto* seq_tensors = reinterpret_cast<PyObject*>(value.ptr());
CreateSequenceOfTensors(alloc, name_input, input_def_list, seq_tensors, p_mlvalue);
} else if (PyDict_Check(value.ptr())) {
#if !defined(DISABLE_ML_OPS)
CreateMapMLValue_AgnosticVectorMap((PyObject*)NULL, value.ptr(), alloc, name_input, p_mlvalue);
#else
ORT_UNUSED_PARAMETER(p_mlvalue);
throw std::runtime_error("Map type is not supported in this build.");
#endif
} else {
auto iterator = PyObject_GetIter(value.ptr());
if (iterator == NULL) {

View file

@ -264,6 +264,7 @@ void AddNonTensorAsPyObj(const OrtValue& val, std::vector<py::object>& pyobjs, c
if (val_type->IsTensorSequenceType()) {
AddNonTensor<TensorSeq>(val, pyobjs, data_transfer_manager);
} else {
#if !defined(DISABLE_ML_OPS)
utils::ContainerChecker c_checker(val_type);
if (c_checker.IsMap()) {
if (c_checker.IsMapOf<std::string, std::string>()) {
@ -283,6 +284,7 @@ void AddNonTensorAsPyObj(const OrtValue& val, std::vector<py::object>& pyobjs, c
} else if (c_checker.IsMapOf<int64_t, double>()) {
AddNonTensor<MapInt64ToDouble>(val, pyobjs, data_transfer_manager);
}
} else {
if (c_checker.IsSequenceOf<std::map<std::string, float>>()) {
AddNonTensor<VectorMapStringToFloat>(val, pyobjs, data_transfer_manager);
@ -292,6 +294,9 @@ void AddNonTensorAsPyObj(const OrtValue& val, std::vector<py::object>& pyobjs, c
throw std::runtime_error("Output is a non-tensor type which is not supported.");
}
}
#else
throw std::runtime_error("Map type is not supported in this build.");
#endif
}
}

View file

@ -28,9 +28,11 @@ struct TestMap {
};
// Try recursive type registration and compatibility tests
using TestMapToMapInt64ToFloat = TestMap<int64_t, MapInt64ToFloat>;
using VectorInt64 = std::vector<int64_t>;
#if !defined(DISABLE_ML_OPS)
using TestMapToMapInt64ToFloat = TestMap<int64_t, MapInt64ToFloat>;
using TestMapStringToVectorInt64 = TestMap<std::string, VectorInt64>;
#endif
// Trial to see if we resolve the setter properly
// a map with a key that has not been registered in data_types.cc
@ -66,19 +68,23 @@ struct TestOpaqueNoNames {};
// use the same cpp runtime types but due to Opaque type domain, name
// and optional parameters we produce separate MLDataTypes that are NOT
// compatible with each other.
#if !defined(DISABLE_ML_OPS)
using MyOpaqueMapCpp_1 = std::map<int64_t, TestOpaqueType_1>;
using MyOpaqueMapCpp_2 = std::map<int64_t, TestOpaqueType_2>;
#endif
// Register Sequence as containing an Opaque type
using MyOpaqueSeqCpp_1 = std::vector<TestOpaqueType_1>;
using MyOpaqueSeqCpp_2 = std::vector<TestOpaqueType_2>;
#if !defined(DISABLE_ML_OPS)
ORT_REGISTER_MAP(MyOpaqueMapCpp_1);
ORT_REGISTER_MAP(MyOpaqueMapCpp_2);
ORT_REGISTER_MAP(TestMapToMapInt64ToFloat);
ORT_REGISTER_MAP(TestMapStringToVectorInt64);
ORT_REGISTER_MAP(TestMapMLFloat16ToFloat);
#endif
ORT_REGISTER_SEQ(MyOpaqueSeqCpp_1);
ORT_REGISTER_SEQ(MyOpaqueSeqCpp_2);
@ -100,12 +106,14 @@ ORT_REGISTER_OPAQUE_TYPE(TestOpaqueNoNames, TestOpaqueEmpty, TestOpaqueEmpty);
}
void RegisterTestTypes() {
#if !defined(DISABLE_ML_OPS)
REGISTER_ONNX_PROTO(MyOpaqueMapCpp_1);
REGISTER_ONNX_PROTO(MyOpaqueMapCpp_2);
REGISTER_ONNX_PROTO(TestMapToMapInt64ToFloat);
REGISTER_ONNX_PROTO(TestMapStringToVectorInt64);
REGISTER_ONNX_PROTO(TestMapMLFloat16ToFloat);
#endif
REGISTER_ONNX_PROTO(MyOpaqueSeqCpp_1);
REGISTER_ONNX_PROTO(MyOpaqueSeqCpp_2);
@ -236,12 +244,15 @@ TEST_F(DataTypeTest, OpaqueRegistrationTest) {
EXPECT_FALSE(utils::IsOpaqueType(op_ml2, TestOpaqueDomain_1, TestOpaqueName_2));
EXPECT_FALSE(utils::IsOpaqueType(DataTypeImpl::GetTensorType<float>(), TestOpaqueDomain_1, TestOpaqueName_1));
#if !defined(DISABLE_ML_OPS)
utils::ContainerChecker c_checker(DataTypeImpl::GetType<MyOpaqueMapCpp_1>());
EXPECT_TRUE(c_checker.IsMap());
bool result = c_checker.IsMapOf<int64_t, TestOpaqueType_1>();
EXPECT_TRUE(result);
#endif
}
#if !defined(DISABLE_ML_OPS)
TEST_F(DataTypeTest, MapStringStringTest) {
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
auto ml_str_str = DataTypeImpl::GetType<MapStringToString>();
@ -251,14 +262,13 @@ TEST_F(DataTypeTest, MapStringStringTest) {
utils::ContainerChecker c_checker(ml_str_str);
bool result = c_checker.IsMapOf<std::string, std::string>();
EXPECT_TRUE(result);
result = c_checker.IsMapOf<std::string, int64_t>();
result = c_checker.IsMapOf<std::string, int64_t>();
EXPECT_FALSE(result);
utils::ContainerChecker c_checker1(DataTypeImpl::GetTensorType<float>());
result = c_checker1.IsMapOf<std::string, int64_t>();
result = c_checker1.IsMapOf<std::string, int64_t>();
EXPECT_FALSE(result);
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_STRING> maps2s_type;
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type.proto));
@ -349,6 +359,7 @@ TEST_F(DataTypeTest, RecursiveMapTest) {
mut_map->mutable_value_type()->CopyFrom(*op2_proto->GetTypeProto());
EXPECT_TRUE(DataTypeImpl::GetType<MyOpaqueMapCpp_2>()->IsCompatible(unod_map_int64_to_op2));
}
#endif // !defined(DISABLE_ML_OPS)
TEST_F(DataTypeTest, RecursiveVectorTest) {
TypeProto seq_of_seq_string;
@ -357,9 +368,12 @@ TEST_F(DataTypeTest, RecursiveVectorTest) {
mut_seq->mutable_elem_type()->mutable_tensor_type()->set_elem_type(TensorProto_DataType_STRING);
EXPECT_TRUE(DataTypeImpl::GetType<TestSequenceOfSequence>()->IsCompatible(seq_of_seq_string));
#if !defined(DISABLE_ML_OPS)
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(seq_of_seq_string));
#endif
}
#if !defined(DISABLE_ML_OPS)
TEST_F(DataTypeTest, VectorMapStringToFloatTest) {
TypeProto vector_map_string_to_float;
vector_map_string_to_float.mutable_sequence_type()->mutable_elem_type()->mutable_map_type()->set_key_type(TensorProto_DataType_STRING);
@ -374,7 +388,7 @@ TEST_F(DataTypeTest, VectorMapStringToFloatTest) {
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(tensor_type.proto));
utils::ContainerChecker c_check(DataTypeImpl::GetType<VectorMapStringToFloat>());
bool result = c_check.IsSequenceOf<MapStringToFloat>();
bool result = c_check.IsSequenceOf<MapStringToFloat>();
EXPECT_TRUE(result);
}
@ -392,6 +406,7 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) {
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2i_type.proto));
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(tensor_type.proto));
}
#endif // !defined(DISABLE_ML_OPS)
TEST_F(DataTypeTest, BFloat16Test) {
// Test data type
@ -508,6 +523,8 @@ TEST_F(DataTypeTest, DataUtilsTest) {
// Expect internalized strings
EXPECT_EQ(ten_dt, ten_from_str);
}
#if !defined(DISABLE_ML_OPS)
// Test Simple map
{
const std::string map_string_string("map(string,tensor(string))");
@ -522,6 +539,7 @@ TEST_F(DataTypeTest, DataUtilsTest) {
const auto& from_dt_proto = DataTypeUtils::ToTypeProto(map_dt);
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToString>()->IsCompatible(from_dt_proto));
}
// Test map with recursive value
{
const std::string map_int_map_int_float("map(int64,map(int64,tensor(float)))");
@ -536,6 +554,7 @@ TEST_F(DataTypeTest, DataUtilsTest) {
const auto& from_dt_proto = DataTypeUtils::ToTypeProto(map_dt);
EXPECT_TRUE(DataTypeImpl::GetType<TestMapToMapInt64ToFloat>()->IsCompatible(from_dt_proto));
}
{
const std::string opaque_map_2("map(int64,opaque(test_domain_2,test_name_2))");
const auto* map_proto = DataTypeImpl::GetType<MyOpaqueMapCpp_2>()->GetTypeProto();
@ -549,6 +568,7 @@ TEST_F(DataTypeTest, DataUtilsTest) {
const auto& from_dt_proto = DataTypeUtils::ToTypeProto(map_dt);
EXPECT_TRUE(DataTypeImpl::GetType<MyOpaqueMapCpp_2>()->IsCompatible(from_dt_proto));
}
// Test Sequence with recursion
{
const std::string seq_map_str_float("seq(map(string,tensor(float)))");
@ -563,6 +583,8 @@ TEST_F(DataTypeTest, DataUtilsTest) {
const auto& from_dt_proto = DataTypeUtils::ToTypeProto(seq_dt);
EXPECT_TRUE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(from_dt_proto));
}
#endif
// Test Sequence with opaque_2
{
const std::string seq_opaque_2("seq(opaque(test_domain_2,test_name_2))");

View file

@ -131,4 +131,4 @@ TEST(SchemaRegistryManager, OpsetRegTest) {
// TODO - Consider making the registration algorithm robust to this invalid usage in general
ASSERT_TRUE(manager.GetSchema("Op5", 5, "Domain1")->since_version() == 1);
ASSERT_TRUE(manager.GetSchema("Op5", 1, "Domain1")->since_version() == 1);
}
}

View file

@ -357,7 +357,11 @@ void CheckDispatch(MLDataType type, const OpTester::Data& expected_data,
void Check(const OpTester::Data& expected_data, OrtValue& ort_value,
const std::string& provider_type) {
CheckDispatch<VectorMapStringToFloat, VectorMapInt64ToFloat, TensorSeq>(
CheckDispatch<
#if !defined(DISABLE_ML_OPS)
VectorMapStringToFloat, VectorMapInt64ToFloat,
#endif
TensorSeq>(
expected_data.data_.Type(), expected_data, ort_value, provider_type);
}

View file

@ -1,13 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <core/common/make_unique.h>
#include "core/session/onnxruntime_cxx_api.h"
#include <functional>
#include <set>
#include "test_allocator.h"
#include <gtest/gtest.h>
#include <iostream>
#include <set>
#include "core/common/make_unique.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test_allocator.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
template <typename T>
struct RelAllocations {
@ -25,6 +28,7 @@ struct RelAllocations {
}
};
#if !defined(DISABLE_ML_OPS)
TEST(CApiTest, CreateGetVectorOfMapsInt64Float) { // support zipmap output type seq(map(int64, float))
// Creation
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
@ -149,6 +153,7 @@ TEST(CApiTest, CreateGetVectorOfMapsStringFloat) { // support zipmap output typ
std::set<float>(std::begin(values), std::end(values)));
}
}
#endif // !defined(DISABLE_ML_OPS)
TEST(CApiTest, TypeInfoMap) {
// Creation
@ -166,6 +171,7 @@ TEST(CApiTest, TypeInfoMap) {
Ort::Value values_tensor = Ort::Value::CreateTensor(info, values.data(), values.size() * sizeof(float),
dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
#if !defined(DISABLE_ML_OPS)
Ort::Value map_ort = Ort::Value::CreateMap(keys_tensor, values_tensor);
Ort::TypeInfo type_info = map_ort.GetTypeInfo();
Ort::MapTypeInfo map_type_info = type_info.GetMapTypeInfo();
@ -183,6 +189,16 @@ TEST(CApiTest, TypeInfoMap) {
map_value_type_info.release();
map_type_info.release();
#else
// until https://github.com/google/googletest/pull/2904/ makes it into a release,
// check an exception is thrown with the expected message the ugly way
try {
Ort::Value map_ort = Ort::Value::CreateMap(keys_tensor, values_tensor);
ASSERT_TRUE(false) << "CreateMap should have throw in this build";
} catch (const Ort::Exception& ex) {
ASSERT_THAT(ex.what(), testing::HasSubstr("Map type is not supported in this build"));
}
#endif
}
TEST(CApiTest, CreateGetSeqTensors) {
@ -223,7 +239,8 @@ TEST(CApiTest, CreateGetSeqStringTensors) {
for (int i = 0; i < N; ++i) {
// create tensor
std::vector<int64_t> shape{2};
auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::ThrowOnError(Ort::GetApi().FillStringTensor(value, string_input_data, 2));
in.push_back(std::move(value));
@ -274,7 +291,8 @@ TEST(CApiTest, TypeInfoSequence) {
ASSERT_EQ(seq_type_info.GetSequenceElementType().GetONNXType(), ONNX_TYPE_TENSOR);
// No shape present, as sequence allows different shapes for each element
// ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetShape(), dims);
ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetElementType(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
seq_type_info.release();
}
}

View file

@ -305,6 +305,7 @@ std::pair<COMPARE_RESULT, std::string> CompareOrtValue(const OrtValue& o, const
return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, "");
}
if (!o.IsTensor()) {
#if !defined(DISABLE_ML_OPS)
if (o.Type() == DataTypeImpl::GetType<VectorMapInt64ToFloat>()) {
return CompareSeqOfMapToFloat(o.Get<VectorMapInt64ToFloat>(), expected_mlvalue.Get<VectorMapInt64ToFloat>(),
per_sample_tolerance, relative_per_sample_tolerance, post_processing);
@ -314,6 +315,9 @@ std::pair<COMPARE_RESULT, std::string> CompareOrtValue(const OrtValue& o, const
per_sample_tolerance, relative_per_sample_tolerance, post_processing);
}
return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, "");
#else
return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, "Map type is not supported in this build.");
#endif
}
const Tensor& outvalue = o.Get<Tensor>();
const Tensor& expected_tensor = expected_mlvalue.Get<Tensor>();