From 08eb15068cf5298643097e144591f88845974b9d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 27 Aug 2020 17:48:12 +1000 Subject: [PATCH] Exclude the Map types from the build if ML ops are disabled. (#4908) * Exclude the Map types from the build if ML ops are disabled. They're the only ops that use Map. --- .../onnxruntime/core/framework/data_types.h | 12 +++++- onnxruntime/core/framework/data_types.cc | 39 ++++++++++++++---- onnxruntime/core/session/onnxruntime_c_api.cc | 41 ++++++++++++++++++- .../python/onnxruntime_pybind_mlvalue.cc | 14 +++++++ .../python/onnxruntime_pybind_state.cc | 5 +++ onnxruntime/test/framework/data_types_test.cc | 32 ++++++++++++--- .../test/ir/schema_registry_manager_test.cc | 2 +- .../test/providers/provider_test_utils.cc | 6 ++- .../test/shared_lib/test_nontensor_types.cc | 34 +++++++++++---- onnxruntime/test/util/compare_ortvalue.cc | 4 ++ 10 files changed, 163 insertions(+), 26 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index 87aa7012b7..5f539cfc43 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -24,7 +24,8 @@ class TypeProto; namespace onnxruntime { /// Predefined registered types -//maps +#if !defined(DISABLE_ML_OPS) +//maps (only used by ML ops) using MapStringToString = std::map; using MapStringToInt64 = std::map; using MapStringToFloat = std::map; @@ -33,10 +34,13 @@ using MapInt64ToString = std::map; using MapInt64ToInt64 = std::map; using MapInt64ToFloat = std::map; using MapInt64ToDouble = std::map; +#endif //vectors/sequences +#if !defined(DISABLE_ML_OPS) using VectorMapStringToFloat = std::vector; using VectorMapInt64ToFloat = std::vector; +#endif using VectorString = std::vector; using VectorInt64 = std::vector; @@ -422,6 +426,7 @@ struct GetMLDataType { } }; +#if !defined(DISABLE_ML_OPS) /// MapTypes helper API /// K should always be one of the primitive data types /// V can be either a primitive type (in which case it is a tensor) @@ -445,6 +450,7 @@ struct SetMapTypes { CopyMutableMapValue(*value_proto, proto); } }; +#endif /// Sequence helpers /// @@ -713,6 +719,7 @@ class NonTensorType : public NonTensorTypeBase { NonTensorType() = default; }; +#if !defined(DISABLE_ML_OPS) /** * \brief MapType. Use this type to register * mapping types. @@ -741,6 +748,7 @@ class MapType : public NonTensorType { SetMapTypes::Set(this->mutable_type_proto()); } }; +#endif /** * \brief SequenceType. Use to register sequence for non-tensor types. @@ -968,6 +976,7 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { return SparseTensorType::Type(); \ } +#if !defined(DISABLE_ML_OPS) #define ORT_REGISTER_MAP(TYPE) \ template <> \ MLDataType MapType::Type() { \ @@ -978,6 +987,7 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { MLDataType DataTypeImpl::GetType() { \ return MapType::Type(); \ } +#endif #define ORT_REGISTER_SEQ(TYPE) \ template <> \ diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 51d8078f10..fe445c0de7 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -60,9 +60,13 @@ struct TensorElementTypeSetter { static void SetSparseTensorElementType(ONNX_NAMESPACE::TypeProto& proto) { proto.mutable_sparse_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); } + +#if !defined(DISABLE_ML_OPS) static void SetMapKeyType(ONNX_NAMESPACE::TypeProto& proto) { proto.mutable_map_type()->set_key_type(utils::ToTensorProtoElementType()); } +#endif + constexpr static int32_t GetElementType() { return utils::ToTensorProtoElementType(); } @@ -98,10 +102,12 @@ template struct template struct TensorElementTypeSetter; +#if !defined(DISABLE_ML_OPS) void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto& value_proto, ONNX_NAMESPACE::TypeProto& map_proto) { map_proto.mutable_map_type()->mutable_value_type()->CopyFrom(value_proto); } +#endif void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto& elem_proto, ONNX_NAMESPACE::TypeProto& proto) { @@ -121,8 +127,10 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, bool IsCompatible(const ONNX_NAMESPACE::TypeProto_SparseTensor& tensor_proto, const ONNX_NAMESPACE::TypeProto_SparseTensor& type_proto); +#if !defined(DISABLE_ML_OPS) bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto, const ONNX_NAMESPACE::TypeProto_Map& type_proto); +#endif bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto, const ONNX_NAMESPACE::TypeProto_Sequence& type_proto); @@ -139,6 +147,7 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, */ } +#if !defined(DISABLE_ML_OPS) bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto, const ONNX_NAMESPACE::TypeProto_Map& type_proto) { const auto& lhs = map_proto; @@ -171,6 +180,7 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto, } return result; } +#endif bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto, const ONNX_NAMESPACE::TypeProto_Sequence& type_proto) { @@ -185,9 +195,11 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto, case TypeProto::ValueCase::kSequenceType: result = IsCompatible(lhs.elem_type().sequence_type(), rhs.elem_type().sequence_type()); break; +#if !defined(DISABLE_ML_OPS) case TypeProto::ValueCase::kMapType: result = IsCompatible(lhs.elem_type().map_type(), rhs.elem_type().map_type()); break; +#endif case TypeProto::ValueCase::kOpaqueType: result = IsCompatible(lhs.elem_type().opaque_type(), rhs.elem_type().opaque_type()); break; @@ -203,7 +215,8 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto, } return result; } -bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto, const ONNX_NAMESPACE::TypeProto_Opaque& type_proto) { +bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto, + const ONNX_NAMESPACE::TypeProto_Opaque& type_proto) { const auto& lhs = opaque_proto; const auto& rhs = type_proto; bool lhs_domain = utils::HasDomain(lhs); @@ -454,6 +467,7 @@ const ONNX_NAMESPACE::TypeProto* NonTensorTypeBase::GetTypeProto() const { return impl_->GetProto(); } +#if !defined(DISABLE_ML_OPS) bool NonTensorTypeBase::IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const { const auto* thisProto = impl_->GetProto(); if (&type_proto == thisProto) { @@ -467,6 +481,7 @@ bool NonTensorTypeBase::IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_pr ORT_ENFORCE(utils::HasKeyType(thisProto->map_type())); return data_types_internal::IsCompatible(thisProto->map_type(), type_proto.map_type()); } +#endif bool NonTensorTypeBase::IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const { const auto* thisProto = impl_->GetProto(); @@ -532,6 +547,7 @@ ORT_REGISTER_SPARSE_TENSOR_TYPE(uint64_t); ORT_REGISTER_SPARSE_TENSOR_TYPE(MLFloat16); ORT_REGISTER_SPARSE_TENSOR_TYPE(BFloat16); +#if !defined(DISABLE_ML_OPS) ORT_REGISTER_MAP(MapStringToString); ORT_REGISTER_MAP(MapStringToInt64); ORT_REGISTER_MAP(MapStringToFloat); @@ -540,6 +556,7 @@ ORT_REGISTER_MAP(MapInt64ToString); ORT_REGISTER_MAP(MapInt64ToInt64); ORT_REGISTER_MAP(MapInt64ToFloat); ORT_REGISTER_MAP(MapInt64ToDouble); +#endif ORT_REGISTER_SEQ_TENSOR_TYPE(float); ORT_REGISTER_SEQ_TENSOR_TYPE(double); @@ -556,8 +573,10 @@ ORT_REGISTER_SEQ_TENSOR_TYPE(std::string); ORT_REGISTER_SEQ_TENSOR_TYPE(MLFloat16); ORT_REGISTER_SEQ_TENSOR_TYPE(BFloat16); +#if !defined(DISABLE_ML_OPS) ORT_REGISTER_SEQ(VectorMapStringToFloat); ORT_REGISTER_SEQ(VectorMapInt64ToFloat); +#endif // Used for Tensor Proto registrations #define REGISTER_TENSOR_PROTO(TYPE, reg_fn) \ @@ -617,6 +636,7 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_SPARSE_TENSOR_PROTO(MLFloat16, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(BFloat16, reg_fn); +#if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(MapStringToString, reg_fn); REGISTER_ONNX_PROTO(MapStringToInt64, reg_fn); REGISTER_ONNX_PROTO(MapStringToFloat, reg_fn); @@ -625,6 +645,7 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_ONNX_PROTO(MapInt64ToInt64, reg_fn); REGISTER_ONNX_PROTO(MapInt64ToFloat, reg_fn); REGISTER_ONNX_PROTO(MapInt64ToDouble, reg_fn); +#endif REGISTER_SEQ_TENSOR_PROTO(int32_t, reg_fn); REGISTER_SEQ_TENSOR_PROTO(float, reg_fn); @@ -641,8 +662,10 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_SEQ_TENSOR_PROTO(MLFloat16, reg_fn); REGISTER_SEQ_TENSOR_PROTO(BFloat16, reg_fn); +#if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn); REGISTER_ONNX_PROTO(VectorMapInt64ToFloat, reg_fn); +#endif } } // namespace data_types_internal @@ -982,18 +1005,18 @@ ContainerChecker::ContainerChecker(MLDataType ml_type) { types_.emplace_back(ContainerType::kTensor, type_proto->tensor_type().elem_type()); type_proto = nullptr; break; - case TypeProto::ValueCase::kMapType: - { +#if !defined(DISABLE_ML_OPS) + case TypeProto::ValueCase::kMapType: { const auto& map_type = type_proto->map_type(); types_.emplace_back(ContainerType::kMap, map_type.key_type()); // Move on handling the value type_proto = &map_type.value_type(); - } - break; + } break; +#endif case TypeProto::ValueCase::kSequenceType: - types_.emplace_back(ContainerType::kSequence, TensorProto_DataType_UNDEFINED); - type_proto = &type_proto->sequence_type().elem_type(); - break; + types_.emplace_back(ContainerType::kSequence, TensorProto_DataType_UNDEFINED); + type_proto = &type_proto->sequence_type().elem_type(); + break; case TypeProto::ValueCase::kOpaqueType: // We do not handle this and terminate here types_.emplace_back(ContainerType::kOpaque, diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 6b04022a95..adf20b5159 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1084,8 +1084,6 @@ ORT_API_STATUS_IMPL(OrtApis::AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Ou API_IMPL_END } -static const int NUM_MAP_INDICES = 2; - template ORT_STATUS_PTR OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) { auto& data = p_ml_value->Get(); @@ -1100,13 +1098,21 @@ ORT_STATUS_PTR OrtGetNumSequenceElements(const OrtValue* p_ml_value, return nullptr; } +#if !defined(DISABLE_ML_OPS) +static const int NUM_MAP_INDICES = 2; +#endif + static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) { ONNXType value_type; if (auto status = OrtApis::GetValueType(value, &value_type)) return status; if (value_type == ONNX_TYPE_MAP) { +#if !defined(DISABLE_ML_OPS) *out = NUM_MAP_INDICES; return nullptr; +#else + return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build."); +#endif } if (value_type == ONNX_TYPE_SEQUENCE) { auto v = reinterpret_cast(value); @@ -1115,6 +1121,7 @@ static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) { if (type->IsTensorSequenceType()) { return OrtGetNumSequenceElements(v, out); } else { +#if !defined(DISABLE_ML_OPS) utils::ContainerChecker c_checker(type); if (c_checker.IsSequenceOf>()) { return OrtGetNumSequenceElements(v, out); @@ -1123,6 +1130,9 @@ static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) { } else { return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported sequence types."); } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build."); +#endif } } else { return OrtApis::CreateStatus(ORT_FAIL, "Input is not of type sequence or map."); @@ -1135,6 +1145,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueCount, _In_ const OrtValue* value, _Out_ si API_IMPL_END } +#if !defined(DISABLE_ML_OPS) /////////////////// // OrtGetValueImplSeqOfMap template @@ -1153,6 +1164,7 @@ static ORT_STATUS_PTR OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int in *out = value.release(); return nullptr; } +#endif ORT_STATUS_PTR PopulateTensorWithData(_Inout_ OrtValue* oval, _In_ const void* data_elem, size_t num_elems, size_t elem_size) { @@ -1231,6 +1243,7 @@ static ORT_STATUS_PTR OrtGetValueImplSeq(_In_ const OrtValue* value, int index, if (type->IsTensorSequenceType()) { return OrtGetValueImplSeqOfTensors(p_ml_value, index, allocator, out); } else { +#if !defined(DISABLE_ML_OPS) utils::ContainerChecker c_checker(type); if (c_checker.IsSequenceOf>()) { return OrtGetValueImplSeqOfMap(p_ml_value, index, out); @@ -1239,9 +1252,13 @@ static ORT_STATUS_PTR OrtGetValueImplSeq(_In_ const OrtValue* value, int index, } else { return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported sequence types."); } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build."); +#endif } } +#if !defined(DISABLE_ML_OPS) template static ORT_STATUS_PTR OrtGetValueImplMapHelper(_In_ const OrtValue* p_ml_value, int index, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { @@ -1308,6 +1325,7 @@ static ORT_STATUS_PTR OrtGetValueImplMap(_In_ const OrtValue* value, int index, } return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types."); } +#endif static ORT_STATUS_PTR OrtGetValueImpl(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { @@ -1315,7 +1333,11 @@ static ORT_STATUS_PTR OrtGetValueImpl(_In_ const OrtValue* value, int index, _In if (auto status = OrtApis::GetValueType(value, &value_type)) return status; if (value_type == ONNX_TYPE_MAP) { +#if !defined(DISABLE_ML_OPS) return OrtGetValueImplMap(value, index, allocator, out); +#else + return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build."); +#endif } if (value_type == ONNX_TYPE_SEQUENCE) { return OrtGetValueImplSeq(value, index, allocator, out); @@ -1333,6 +1355,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetValue, _In_ const OrtValue* value, int index, _I /////////////////// // OrtCreateValue + +#if !defined(DISABLE_ML_OPS) template static OrtStatus* OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size_t num_values, _Outptr_ OrtValue** out) { @@ -1352,6 +1376,7 @@ static OrtStatus* OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size *out = value.release(); return nullptr; } +#endif template static OrtStatus* OrtCreateValueImplSeqHelperTensor(const Tensor& tensor, @@ -1461,6 +1486,7 @@ static ORT_STATUS_PTR OrtCreateValueImplSeq(_In_reads_(num_values) const OrtValu if (first_value_type == ONNX_TYPE_TENSOR) { return OrtCreateValueImplSeqHelper(in, num_values, out); } else if (first_value_type == ONNX_TYPE_MAP) { +#if !defined(DISABLE_ML_OPS) auto map_type = first_mlvalue->Type(); utils::ContainerChecker c_checker(map_type); if (c_checker.IsMapOf()) { @@ -1471,11 +1497,17 @@ static ORT_STATUS_PTR OrtCreateValueImplSeq(_In_reads_(num_values) const OrtValu } else { return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types."); } +#else + ORT_UNUSED_PARAMETER(first_mlvalue); + return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build."); +#endif + } else { return OrtApis::CreateStatus(ORT_FAIL, "Unsupported input type"); } } +#if !defined(DISABLE_ML_OPS) template static OrtStatus* OrtCreateMapMLValue(const Tensor& key_tensor, const Tensor& value_tensor, _Outptr_ OrtValue** out) { using MapType = std::map; @@ -1559,6 +1591,7 @@ static ORT_STATUS_PTR OrtCreateValueImplMap(const OrtValue* const* in, size_t nu } return OrtApis::CreateStatus(ORT_FAIL, "Key type is not supported yet."); } +#endif static ORT_STATUS_PTR OrtCreateValueImpl(_In_reads_(num_values) const OrtValue* const* in, size_t num_values, enum ONNXType value_type, _Outptr_ OrtValue** out) { @@ -1566,7 +1599,11 @@ static ORT_STATUS_PTR OrtCreateValueImpl(_In_reads_(num_values) const OrtValue* return OrtApis::CreateStatus(ORT_FAIL, "Number of values should be at least 1."); } if (value_type == ONNX_TYPE_MAP) { +#if !defined(DISABLE_ML_OPS) return OrtCreateValueImplMap(in, num_values, out); +#else + return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build."); +#endif } if (value_type == ONNX_TYPE_SEQUENCE) { return OrtCreateValueImplSeq(in, num_values, out); diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index c8055cf854..293b594b0f 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -387,6 +387,7 @@ std::string _get_type_name(std::string&) { return std::string("string"); } +#if !defined(DISABLE_ML_OPS) template void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string& name_input, PyObject*& value, PyObject* item, std::map& current, @@ -549,6 +550,7 @@ void CreateMapMLValue_AgnosticVectorMap(PyObject* iterator, PyObject* item, Allo throw std::runtime_error("Size of dictionary is empty, unable to run the prediction."); } } +#endif void CreateGenericIterableMLValue(PyObject* iterator, AllocatorPtr alloc, const std::string& name_input, OrtValue* p_mlvalue) { @@ -573,7 +575,13 @@ void CreateGenericIterableMLValue(PyObject* iterator, AllocatorPtr alloc, const throw std::runtime_error("Input must be a list of dictionaries or a single numpy array for input '" + name_input + std::string("'.")); } +#if !defined(DISABLE_ML_OPS) CreateMapMLValue_AgnosticVectorMap(iterator, item, alloc, name_input, p_mlvalue); +#else + ORT_UNUSED_PARAMETER(alloc); + ORT_UNUSED_PARAMETER(p_mlvalue); + throw std::runtime_error("Map type is not supported in this build."); +#endif } } @@ -610,7 +618,13 @@ void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const auto* seq_tensors = reinterpret_cast(value.ptr()); CreateSequenceOfTensors(alloc, name_input, input_def_list, seq_tensors, p_mlvalue); } else if (PyDict_Check(value.ptr())) { +#if !defined(DISABLE_ML_OPS) CreateMapMLValue_AgnosticVectorMap((PyObject*)NULL, value.ptr(), alloc, name_input, p_mlvalue); +#else + ORT_UNUSED_PARAMETER(p_mlvalue); + throw std::runtime_error("Map type is not supported in this build."); +#endif + } else { auto iterator = PyObject_GetIter(value.ptr()); if (iterator == NULL) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 84471ab68a..9d3838cbb5 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -264,6 +264,7 @@ void AddNonTensorAsPyObj(const OrtValue& val, std::vector& pyobjs, c if (val_type->IsTensorSequenceType()) { AddNonTensor(val, pyobjs, data_transfer_manager); } else { +#if !defined(DISABLE_ML_OPS) utils::ContainerChecker c_checker(val_type); if (c_checker.IsMap()) { if (c_checker.IsMapOf()) { @@ -283,6 +284,7 @@ void AddNonTensorAsPyObj(const OrtValue& val, std::vector& pyobjs, c } else if (c_checker.IsMapOf()) { AddNonTensor(val, pyobjs, data_transfer_manager); } + } else { if (c_checker.IsSequenceOf>()) { AddNonTensor(val, pyobjs, data_transfer_manager); @@ -292,6 +294,9 @@ void AddNonTensorAsPyObj(const OrtValue& val, std::vector& pyobjs, c throw std::runtime_error("Output is a non-tensor type which is not supported."); } } +#else + throw std::runtime_error("Map type is not supported in this build."); +#endif } } diff --git a/onnxruntime/test/framework/data_types_test.cc b/onnxruntime/test/framework/data_types_test.cc index 74ba617469..a6822bc89f 100644 --- a/onnxruntime/test/framework/data_types_test.cc +++ b/onnxruntime/test/framework/data_types_test.cc @@ -28,9 +28,11 @@ struct TestMap { }; // Try recursive type registration and compatibility tests -using TestMapToMapInt64ToFloat = TestMap; using VectorInt64 = std::vector; +#if !defined(DISABLE_ML_OPS) +using TestMapToMapInt64ToFloat = TestMap; using TestMapStringToVectorInt64 = TestMap; +#endif // Trial to see if we resolve the setter properly // a map with a key that has not been registered in data_types.cc @@ -66,19 +68,23 @@ struct TestOpaqueNoNames {}; // use the same cpp runtime types but due to Opaque type domain, name // and optional parameters we produce separate MLDataTypes that are NOT // compatible with each other. +#if !defined(DISABLE_ML_OPS) using MyOpaqueMapCpp_1 = std::map; using MyOpaqueMapCpp_2 = std::map; +#endif // Register Sequence as containing an Opaque type using MyOpaqueSeqCpp_1 = std::vector; using MyOpaqueSeqCpp_2 = std::vector; +#if !defined(DISABLE_ML_OPS) ORT_REGISTER_MAP(MyOpaqueMapCpp_1); ORT_REGISTER_MAP(MyOpaqueMapCpp_2); ORT_REGISTER_MAP(TestMapToMapInt64ToFloat); ORT_REGISTER_MAP(TestMapStringToVectorInt64); ORT_REGISTER_MAP(TestMapMLFloat16ToFloat); +#endif ORT_REGISTER_SEQ(MyOpaqueSeqCpp_1); ORT_REGISTER_SEQ(MyOpaqueSeqCpp_2); @@ -100,12 +106,14 @@ ORT_REGISTER_OPAQUE_TYPE(TestOpaqueNoNames, TestOpaqueEmpty, TestOpaqueEmpty); } void RegisterTestTypes() { +#if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(MyOpaqueMapCpp_1); REGISTER_ONNX_PROTO(MyOpaqueMapCpp_2); REGISTER_ONNX_PROTO(TestMapToMapInt64ToFloat); REGISTER_ONNX_PROTO(TestMapStringToVectorInt64); REGISTER_ONNX_PROTO(TestMapMLFloat16ToFloat); +#endif REGISTER_ONNX_PROTO(MyOpaqueSeqCpp_1); REGISTER_ONNX_PROTO(MyOpaqueSeqCpp_2); @@ -236,12 +244,15 @@ TEST_F(DataTypeTest, OpaqueRegistrationTest) { EXPECT_FALSE(utils::IsOpaqueType(op_ml2, TestOpaqueDomain_1, TestOpaqueName_2)); EXPECT_FALSE(utils::IsOpaqueType(DataTypeImpl::GetTensorType(), TestOpaqueDomain_1, TestOpaqueName_1)); +#if !defined(DISABLE_ML_OPS) utils::ContainerChecker c_checker(DataTypeImpl::GetType()); EXPECT_TRUE(c_checker.IsMap()); bool result = c_checker.IsMapOf(); EXPECT_TRUE(result); +#endif } +#if !defined(DISABLE_ML_OPS) TEST_F(DataTypeTest, MapStringStringTest) { TensorTypeProto tensor_type; auto ml_str_str = DataTypeImpl::GetType(); @@ -251,14 +262,13 @@ TEST_F(DataTypeTest, MapStringStringTest) { utils::ContainerChecker c_checker(ml_str_str); bool result = c_checker.IsMapOf(); EXPECT_TRUE(result); - result = c_checker.IsMapOf(); + result = c_checker.IsMapOf(); EXPECT_FALSE(result); utils::ContainerChecker c_checker1(DataTypeImpl::GetTensorType()); - result = c_checker1.IsMapOf(); + result = c_checker1.IsMapOf(); EXPECT_FALSE(result); - MapTypeProto maps2s_type; MapTypeProto maps2i_type; EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type.proto)); @@ -349,6 +359,7 @@ TEST_F(DataTypeTest, RecursiveMapTest) { mut_map->mutable_value_type()->CopyFrom(*op2_proto->GetTypeProto()); EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(unod_map_int64_to_op2)); } +#endif // !defined(DISABLE_ML_OPS) TEST_F(DataTypeTest, RecursiveVectorTest) { TypeProto seq_of_seq_string; @@ -357,9 +368,12 @@ TEST_F(DataTypeTest, RecursiveVectorTest) { mut_seq->mutable_elem_type()->mutable_tensor_type()->set_elem_type(TensorProto_DataType_STRING); EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(seq_of_seq_string)); +#if !defined(DISABLE_ML_OPS) EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(seq_of_seq_string)); +#endif } +#if !defined(DISABLE_ML_OPS) TEST_F(DataTypeTest, VectorMapStringToFloatTest) { TypeProto vector_map_string_to_float; vector_map_string_to_float.mutable_sequence_type()->mutable_elem_type()->mutable_map_type()->set_key_type(TensorProto_DataType_STRING); @@ -374,7 +388,7 @@ TEST_F(DataTypeTest, VectorMapStringToFloatTest) { EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type.proto)); EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); utils::ContainerChecker c_check(DataTypeImpl::GetType()); - bool result = c_check.IsSequenceOf(); + bool result = c_check.IsSequenceOf(); EXPECT_TRUE(result); } @@ -392,6 +406,7 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) { EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(mapi2i_type.proto)); EXPECT_FALSE(DataTypeImpl::GetType()->IsCompatible(tensor_type.proto)); } +#endif // !defined(DISABLE_ML_OPS) TEST_F(DataTypeTest, BFloat16Test) { // Test data type @@ -508,6 +523,8 @@ TEST_F(DataTypeTest, DataUtilsTest) { // Expect internalized strings EXPECT_EQ(ten_dt, ten_from_str); } + +#if !defined(DISABLE_ML_OPS) // Test Simple map { const std::string map_string_string("map(string,tensor(string))"); @@ -522,6 +539,7 @@ TEST_F(DataTypeTest, DataUtilsTest) { const auto& from_dt_proto = DataTypeUtils::ToTypeProto(map_dt); EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(from_dt_proto)); } + // Test map with recursive value { const std::string map_int_map_int_float("map(int64,map(int64,tensor(float)))"); @@ -536,6 +554,7 @@ TEST_F(DataTypeTest, DataUtilsTest) { const auto& from_dt_proto = DataTypeUtils::ToTypeProto(map_dt); EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(from_dt_proto)); } + { const std::string opaque_map_2("map(int64,opaque(test_domain_2,test_name_2))"); const auto* map_proto = DataTypeImpl::GetType()->GetTypeProto(); @@ -549,6 +568,7 @@ TEST_F(DataTypeTest, DataUtilsTest) { const auto& from_dt_proto = DataTypeUtils::ToTypeProto(map_dt); EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(from_dt_proto)); } + // Test Sequence with recursion { const std::string seq_map_str_float("seq(map(string,tensor(float)))"); @@ -563,6 +583,8 @@ TEST_F(DataTypeTest, DataUtilsTest) { const auto& from_dt_proto = DataTypeUtils::ToTypeProto(seq_dt); EXPECT_TRUE(DataTypeImpl::GetType()->IsCompatible(from_dt_proto)); } +#endif + // Test Sequence with opaque_2 { const std::string seq_opaque_2("seq(opaque(test_domain_2,test_name_2))"); diff --git a/onnxruntime/test/ir/schema_registry_manager_test.cc b/onnxruntime/test/ir/schema_registry_manager_test.cc index 3c5e783689..7e5b51415d 100644 --- a/onnxruntime/test/ir/schema_registry_manager_test.cc +++ b/onnxruntime/test/ir/schema_registry_manager_test.cc @@ -131,4 +131,4 @@ TEST(SchemaRegistryManager, OpsetRegTest) { // TODO - Consider making the registration algorithm robust to this invalid usage in general ASSERT_TRUE(manager.GetSchema("Op5", 5, "Domain1")->since_version() == 1); ASSERT_TRUE(manager.GetSchema("Op5", 1, "Domain1")->since_version() == 1); -} \ No newline at end of file +} diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 687a63f339..ebe72559db 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -357,7 +357,11 @@ void CheckDispatch(MLDataType type, const OpTester::Data& expected_data, void Check(const OpTester::Data& expected_data, OrtValue& ort_value, const std::string& provider_type) { - CheckDispatch( + CheckDispatch< +#if !defined(DISABLE_ML_OPS) + VectorMapStringToFloat, VectorMapInt64ToFloat, +#endif + TensorSeq>( expected_data.data_.Type(), expected_data, ort_value, provider_type); } diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc index 5f68fa59ba..d0f412b7cd 100644 --- a/onnxruntime/test/shared_lib/test_nontensor_types.cc +++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc @@ -1,13 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include "core/session/onnxruntime_cxx_api.h" #include -#include -#include "test_allocator.h" -#include #include +#include + +#include "core/common/make_unique.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test_allocator.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" template struct RelAllocations { @@ -25,6 +28,7 @@ struct RelAllocations { } }; +#if !defined(DISABLE_ML_OPS) TEST(CApiTest, CreateGetVectorOfMapsInt64Float) { // support zipmap output type seq(map(int64, float)) // Creation auto default_allocator = onnxruntime::make_unique(); @@ -149,6 +153,7 @@ TEST(CApiTest, CreateGetVectorOfMapsStringFloat) { // support zipmap output typ std::set(std::begin(values), std::end(values))); } } +#endif // !defined(DISABLE_ML_OPS) TEST(CApiTest, TypeInfoMap) { // Creation @@ -166,6 +171,7 @@ TEST(CApiTest, TypeInfoMap) { Ort::Value values_tensor = Ort::Value::CreateTensor(info, values.data(), values.size() * sizeof(float), dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); +#if !defined(DISABLE_ML_OPS) Ort::Value map_ort = Ort::Value::CreateMap(keys_tensor, values_tensor); Ort::TypeInfo type_info = map_ort.GetTypeInfo(); Ort::MapTypeInfo map_type_info = type_info.GetMapTypeInfo(); @@ -183,6 +189,16 @@ TEST(CApiTest, TypeInfoMap) { map_value_type_info.release(); map_type_info.release(); +#else + // until https://github.com/google/googletest/pull/2904/ makes it into a release, + // check an exception is thrown with the expected message the ugly way + try { + Ort::Value map_ort = Ort::Value::CreateMap(keys_tensor, values_tensor); + ASSERT_TRUE(false) << "CreateMap should have throw in this build"; + } catch (const Ort::Exception& ex) { + ASSERT_THAT(ex.what(), testing::HasSubstr("Map type is not supported in this build")); + } +#endif } TEST(CApiTest, CreateGetSeqTensors) { @@ -223,7 +239,8 @@ TEST(CApiTest, CreateGetSeqStringTensors) { for (int i = 0; i < N; ++i) { // create tensor std::vector shape{2}; - auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); + auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); Ort::ThrowOnError(Ort::GetApi().FillStringTensor(value, string_input_data, 2)); in.push_back(std::move(value)); @@ -274,7 +291,8 @@ TEST(CApiTest, TypeInfoSequence) { ASSERT_EQ(seq_type_info.GetSequenceElementType().GetONNXType(), ONNX_TYPE_TENSOR); // No shape present, as sequence allows different shapes for each element // ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetShape(), dims); - ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); + ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); seq_type_info.release(); -} \ No newline at end of file +} diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index fc3d229d68..0ca96739fa 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -305,6 +305,7 @@ std::pair CompareOrtValue(const OrtValue& o, const return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, ""); } if (!o.IsTensor()) { +#if !defined(DISABLE_ML_OPS) if (o.Type() == DataTypeImpl::GetType()) { return CompareSeqOfMapToFloat(o.Get(), expected_mlvalue.Get(), per_sample_tolerance, relative_per_sample_tolerance, post_processing); @@ -314,6 +315,9 @@ std::pair CompareOrtValue(const OrtValue& o, const per_sample_tolerance, relative_per_sample_tolerance, post_processing); } return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); +#else + return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, "Map type is not supported in this build."); +#endif } const Tensor& outvalue = o.Get(); const Tensor& expected_tensor = expected_mlvalue.Get();