From 3afb83ac3c66840b99c5854d0a697ebeb8f4ff94 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 10 Feb 2020 16:18:42 -0800 Subject: [PATCH] Support a few new C/C++ APIs (#2794) * Initial commit * More changes * More changes * Changes * More changes * More changes * More changes * More changes * Updates * Fix break * PR feedback * Nit * Resolve conflicts * More changes --- .../core/session/onnxruntime_c_api.h | 39 ++++- .../core/session/onnxruntime_cxx_api.h | 21 ++- .../core/session/onnxruntime_cxx_inline.h | 48 ++++++ onnxruntime/core/session/inference_session.h | 7 + onnxruntime/core/session/onnxruntime_c_api.cc | 110 +++++++++++++- onnxruntime/core/session/ort_apis.h | 20 +++ onnxruntime/test/shared_lib/test_inference.cc | 137 +++++++++++++----- .../model_with_valid_ort_config_json.onnx | Bin 271 -> 322 bytes 8 files changed, 336 insertions(+), 46 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cca39a0862..341fc9d778 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -158,6 +158,7 @@ ORT_RUNTIME_CLASS(SessionOptions); ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(MapTypeInfo); ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(ModelMetadata); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. @@ -381,7 +382,7 @@ struct OrtApi { OrtStatus*(ORT_API_CALL* SessionGetOverridableInitializerTypeInfo)(_In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; /** - * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it. + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. */ OrtStatus*(ORT_API_CALL* SessionGetInputName)(_In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; @@ -710,6 +711,42 @@ struct OrtApi { ORT_CLASS_RELEASE(MapTypeInfo); ORT_CLASS_RELEASE(SequenceTypeInfo); + + /** + * \param out is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + * Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling() + * on the SessionOptions instance used to create the session. + */ + OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, + _Outptr_ char** out)NO_EXCEPTION; + + /** + * \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use. + */ + OrtStatus*(ORT_API_CALL* SessionGetModelMetadata)(_In_ const OrtSession* sess, + _Outptr_ OrtModelMetadata** out)NO_EXCEPTION; + + /** + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + */ + OrtStatus*(ORT_API_CALL* ModelMetadataGetProducerName)(_In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelMetadataGetGraphName)(_In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelMetadataGetDomain)(_In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* ModelMetadataGetDescription)(_In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; + /** + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + * 'value' will be a nullptr if the given key is not found in the custom metadata map. + */ + OrtStatus*(ORT_API_CALL* ModelMetadataLookupCustomMetadataMap)(_In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, + _In_ const char* key, _Outptr_ char** value)NO_EXCEPTION; + + OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION; + + ORT_CLASS_RELEASE(ModelMetadata); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a97a5d413f..5e1c91b916 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -53,7 +53,6 @@ template const OrtApi& Global::api_ = *OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif - // This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions inline const OrtApi& GetApi() { return Global::api_; } @@ -71,6 +70,7 @@ ORT_DEFINE_RELEASE(SessionOptions); ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); ORT_DEFINE_RELEASE(TypeInfo); ORT_DEFINE_RELEASE(Value); +ORT_DEFINE_RELEASE(ModelMetadata); // This is used internally by the C++ API. This is the common base class used by the wrapper objects. template @@ -82,7 +82,7 @@ struct Base { ~Base() { OrtRelease(p_); } operator T*() { return p_; } - operator const T*() const { return p_; } + operator const T *() const { return p_; } T* release() { T* p = p_; @@ -118,6 +118,7 @@ struct MemoryInfo; struct Env; struct TypeInfo; struct Value; +struct ModelMetadata; struct Env : Base { Env(std::nullptr_t) {} @@ -186,6 +187,18 @@ struct SessionOptions : Base { SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); }; +struct ModelMetadata : Base { + explicit ModelMetadata(std::nullptr_t) {} + explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} + + char* GetProducerName(OrtAllocator* allocator) const; + char* GetGraphName(OrtAllocator* allocator) const; + char* GetDomain(OrtAllocator* allocator) const; + char* GetDescription(OrtAllocator* allocator) const; + char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; + int64_t GetVersion() const; +}; + struct Session : Base { explicit Session(std::nullptr_t) {} Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); @@ -205,6 +218,8 @@ struct Session : Base { char* GetInputName(size_t index, OrtAllocator* allocator) const; char* GetOutputName(size_t index, OrtAllocator* allocator) const; char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; + char* EndProfiling(OrtAllocator* allocator) const; + ModelMetadata GetModelMetadata() const; TypeInfo GetInputTypeInfo(size_t index) const; TypeInfo GetOutputTypeInfo(size_t index) const; @@ -274,7 +289,7 @@ struct AllocatorWithDefaultOptions { AllocatorWithDefaultOptions(); operator OrtAllocator*() { return p_; } - operator const OrtAllocator*() const { return p_; } + operator const OrtAllocator *() const { return p_; } void* Alloc(size_t size); void Free(void* p); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index f6fb350171..b10f0cebe5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -278,6 +278,54 @@ inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* return out; } +inline char* Session::EndProfiling(OrtAllocator* allocator) const { + char* out; + ThrowOnError(Global::api_.SessionEndProfiling(p_, allocator, &out)); + return out; +} + +inline ModelMetadata Session::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(Global::api_.SessionGetModelMetadata(p_, &out)); + return ModelMetadata{out}; +} + +inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const { + char* out; + ThrowOnError(Global::api_.ModelMetadataGetProducerName(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const { + char* out; + ThrowOnError(Global::api_.ModelMetadataGetGraphName(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const { + char* out; + ThrowOnError(Global::api_.ModelMetadataGetDomain(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const { + char* out; + ThrowOnError(Global::api_.ModelMetadataGetDescription(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const { + char* out; + ThrowOnError(Global::api_.ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); + return out; +} + +inline int64_t ModelMetadata::GetVersion() const { + int64_t out; + ThrowOnError(Global::api_.ModelMetadataGetVersion(p_, &out)); + return out; +} + inline TypeInfo Session::GetInputTypeInfo(size_t index) const { OrtTypeInfo* out; ThrowOnError(Global::api_.SessionGetInputTypeInfo(p_, index, &out)); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 66e90075e5..b4977a9de5 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -56,6 +56,13 @@ class LoggingManager; * Pre-defined and custom metadata about the model. */ struct ModelMetadata { + ModelMetadata() = default; + ModelMetadata(const ModelMetadata& other) + : producer_name(other.producer_name), graph_name(other.graph_name), domain(other.domain), description(other.description), version(other.version), custom_metadata_map(other.custom_metadata_map) { + } + ~ModelMetadata() = default; + ModelMetadata& operator=(const ModelMetadata&) = delete; + std::string producer_name; std::string graph_name; std::string domain; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b9e1f58c95..974b3bcce5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -49,9 +49,6 @@ using namespace onnxruntime; if (_status) return _status; \ } while (0) - - - #define TENSOR_READ_API_BEGIN \ API_IMPL_BEGIN \ auto v = reinterpret_cast(value); \ @@ -664,6 +661,99 @@ static OrtStatus* GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index, return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, + _Out_ char** out) { + API_IMPL_BEGIN + auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess); + auto profile_file_name = session->EndProfiling(); + *out = StrDup(profile_file_name, allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionGetModelMetadata, _In_ const OrtSession* sess, + _Outptr_ OrtModelMetadata** out) { + API_IMPL_BEGIN + auto session = reinterpret_cast(sess); + auto p = session->GetModelMetadata(); + if (!p.first.IsOK()) + return ToOrtStatus(p.first); + *out = reinterpret_cast(new ModelMetadata(*p.second)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetProducerName, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value) { + API_IMPL_BEGIN + auto producer_name = reinterpret_cast(model_metadata)->producer_name; + *value = StrDup(producer_name, allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphName, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value) { + API_IMPL_BEGIN + auto graph_name = reinterpret_cast(model_metadata)->graph_name; + *value = StrDup(graph_name, allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDomain, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value) { + API_IMPL_BEGIN + auto domain = reinterpret_cast(model_metadata)->domain; + *value = StrDup(domain, allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value) { + API_IMPL_BEGIN + auto description = reinterpret_cast(model_metadata)->description; + *value = StrDup(description, allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, + _In_ const char* key, _Outptr_ char** value) { + API_IMPL_BEGIN + auto custom_metadata_map = + reinterpret_cast(model_metadata)->custom_metadata_map; + + std::string temp(key); + + auto iter = custom_metadata_map.find(temp); + + if (iter == custom_metadata_map.end()) { + *value = nullptr; + } else { + *value = StrDup(iter->second, allocator); + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetVersion, + _In_ const OrtModelMetadata* model_metadata, + _Out_ int64_t* value) { + API_IMPL_BEGIN + *value = reinterpret_cast(model_metadata)->version; + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionGetInputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output) { API_IMPL_BEGIN @@ -1400,7 +1490,16 @@ static constexpr OrtApi ort_api_1_to_2 = { &OrtApis::GetMapValueType, &OrtApis::GetSequenceElementType, &OrtApis::ReleaseMapTypeInfo, - &OrtApis::ReleaseSequenceTypeInfo + &OrtApis::ReleaseSequenceTypeInfo, + &OrtApis::SessionEndProfiling, + &OrtApis::SessionGetModelMetadata, + &OrtApis::ModelMetadataGetProducerName, + &OrtApis::ModelMetadataGetGraphName, + &OrtApis::ModelMetadataGetDomain, + &OrtApis::ModelMetadataGetDescription, + &OrtApis::ModelMetadataLookupCustomMetadataMap, + &OrtApis::ModelMetadataGetVersion, + &OrtApis::ReleaseModelMetadata, }; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) @@ -1428,4 +1527,5 @@ ORT_API(void, OrtApis::ReleaseEnv, _Frees_ptr_opt_ OrtEnv* value) { DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) -DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) \ No newline at end of file +DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) +DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 4e3bf2274a..8ebed4a71b 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -18,6 +18,7 @@ ORT_API(void, ReleaseSessionOptions, OrtSessionOptions*); ORT_API(void, ReleaseCustomOpDomain, OrtCustomOpDomain*); ORT_API(void, ReleaseMapTypeInfo, OrtMapTypeInfo*); ORT_API(void, ReleaseSequenceTypeInfo, OrtSequenceTypeInfo*); +ORT_API(void, ReleaseModelMetadata, OrtModelMetadata*); ORT_API_STATUS_IMPL(CreateStatus, OrtErrorCode code, _In_ const char* msg); OrtErrorCode ORT_API_CALL GetErrorCode(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; @@ -73,6 +74,25 @@ ORT_API_STATUS_IMPL(SessionGetInputName, _In_ const OrtSession* sess, size_t ind ORT_API_STATUS_IMPL(SessionGetOutputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); ORT_API_STATUS_IMPL(SessionGetOverridableInitializerName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); +ORT_API_STATUS_IMPL(SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, + _Outptr_ char** out); +ORT_API_STATUS_IMPL(SessionGetModelMetadata, _In_ const OrtSession* sess, + _Outptr_ OrtModelMetadata** out); + +ORT_API_STATUS_IMPL(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); +ORT_API_STATUS_IMPL(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); +ORT_API_STATUS_IMPL(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); +ORT_API_STATUS_IMPL(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); +ORT_API_STATUS_IMPL(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, + _In_ const char* key, _Outptr_ char** value); + +ORT_API_STATUS_IMPL(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, + _Out_ int64_t* value); ORT_API_STATUS_IMPL(CreateRunOptions, _Outptr_ OrtRunOptions** out); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4e446ae7fa..08726f8bf0 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -57,14 +57,13 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object, } } - template void TestInference(Ort::Env& env, T model_uri, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y, - int provider_type, + int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const char* custom_op_library_filename) { Ort::SessionOptions session_options; @@ -97,8 +96,8 @@ void TestInference(Ort::Env& env, T model_uri, session_options.Add(custom_op_domain_ptr); } - if (custom_op_library_filename){ - void* library_handle = nullptr; // leak this, no harm. + if (custom_op_library_filename) { + void* library_handle = nullptr; // leak this, no harm. Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &library_handle); } @@ -107,24 +106,24 @@ void TestInference(Ort::Env& env, T model_uri, // Now run //without preallocated output tensor RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - nullptr); + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + nullptr); //with preallocated output tensor Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); //test it twice for (int i = 0; i != 2; ++i) RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - &value_y); + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); } static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx"); @@ -132,6 +131,7 @@ static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx"); static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx"); static constexpr PATH_TYPE OVERRIDABLE_INITIALIZER_MODEL_URI = TSTR("testdata/overridable_initializer.onnx"); static constexpr PATH_TYPE NAMED_AND_ANON_DIM_PARAM_URI = TSTR("testdata/capi_symbolic_dims.onnx"); +static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/model_with_valid_ort_config_json.onnx"); #ifdef ENABLE_LANGUAGE_INTEROP_OPS static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx"); @@ -267,36 +267,34 @@ TEST(CApiTest, DISABLED_test_custom_op_library) { std::vector inputs(2); inputs[0].name = "input_1"; inputs[0].dims = {3, 5}; - inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, - 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, - 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; inputs[1].name = "input_2"; inputs[1].dims = {3, 5}; - inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f, - 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, - 5.5f, 4.4f, 3.3f, 2.2f, 1.1f}; + inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f, + 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, + 5.5f, 4.4f, 3.3f, 2.2f, 1.1f}; // prepare expected inputs and outputs std::vector expected_dims_y = {3, 5}; - std::vector expected_values_y = - {17, 17, 17, 17, 17, - 17, 18, 18, 18, 17, - 17, 17, 17, 17, 17}; + std::vector expected_values_y = + {17, 17, 17, 17, 17, + 17, 18, 18, 18, 17, + 17, 17, 17, 17, 17}; std::string lib_name; - #if defined(_WIN32) - lib_name = "custom_op_library.dll"; - #elif defined(__APPLE__) - lib_name = "libcustom_op_library.dylib"; - #else - lib_name = "libcustom_op_library.so"; - #endif +#if defined(_WIN32) + lib_name = "custom_op_library.dll"; +#elif defined(__APPLE__) + lib_name = "libcustom_op_library.dylib"; +#else + lib_name = "libcustom_op_library.so"; +#endif - TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str()); + TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str()); } - - #if defined(ENABLE_LANGUAGE_INTEROP_OPS) && !defined(_WIN32) // on windows, PYTHONHOME must be set explicitly TEST(CApiTest, DISABLED_test_pyop) { std::cout << "Test model with pyop" << std::endl; @@ -422,4 +420,69 @@ TEST(CApiTest, override_initializer) { ASSERT_EQ(type_info.GetElementCount(), 1U); float* output_data = ort_outputs[2].GetTensorMutableData(); ASSERT_EQ(*output_data, f11_input_data[0]); -} \ No newline at end of file +} + +TEST(CApiTest, end_profiling) { + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + auto allocator = onnxruntime::make_unique(); + + // Create session with profiling enabled (profiling is automatically turned on) + Ort::SessionOptions session_options_1; +#ifdef _WIN32 + session_options_1.EnableProfiling(L"profile_prefix"); +#else + session_options_1.EnableProfiling("profile_prefix"); +#endif + Ort::Session session_1(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_1); + char* profile_file = session_1.EndProfiling(allocator.get()); + + ASSERT_TRUE(std::string(profile_file).find("profile_prefix") != std::string::npos); + + // Create session with profiling disabled + Ort::SessionOptions session_options_2; +#ifdef _WIN32 + session_options_2.DisableProfiling(); +#else + session_options_2.DisableProfiling(); +#endif + Ort::Session session_2(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_2); + profile_file = session_2.EndProfiling(allocator.get()); + + ASSERT_TRUE(std::string(profile_file) == std::string()); +} + +TEST(CApiTest, model_metadata) { + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + auto allocator = onnxruntime::make_unique(); + + // Create session + Ort::SessionOptions session_options; + Ort::Session session(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options); + + // Fetch model metadata + // The following all tap into the c++ APIs which internally wrap over C APIs + auto model_metadata = session.GetModelMetadata(); + + char* producer_name = model_metadata.GetProducerName(allocator.get()); + ASSERT_TRUE(strcmp("Hari", producer_name) == 0); + + char* graph_name = model_metadata.GetGraphName(allocator.get()); + ASSERT_TRUE(strcmp("matmul test", graph_name) == 0); + + char* domain = model_metadata.GetDomain(allocator.get()); + ASSERT_TRUE(strcmp("", domain) == 0); + + char* description = model_metadata.GetDescription(allocator.get()); + ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0); + + int64_t version = model_metadata.GetVersion(); + ASSERT_TRUE(version == 1); + + char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); + ASSERT_TRUE(strcmp(lookup_value, + "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); + + // key doesn't exist in custom metadata map + lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get()); + ASSERT_TRUE(lookup_value == nullptr); +} diff --git a/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx b/onnxruntime/test/testdata/model_with_valid_ort_config_json.onnx index 160bbbd2d4a50d7b62e664e237078cc6e065a3bb..a57b83f71aa30c39b7d27dbf4773effcd194f06c 100644 GIT binary patch delta 70 zcmeBYI>cnl!7Rk$kywC{Jf1^tc(ENi4~my delta 19 acmX@a)X&7l!7Rj