Support a few new C/C++ APIs (#2794)

* Initial commit

* More changes

* More changes

* Changes

* More changes

* More changes

* More changes

* More changes

* Updates

* Fix break

* PR feedback

* Nit

* Resolve conflicts

* More changes
This commit is contained in:
Hariharan Seshadri 2020-02-10 16:18:42 -08:00 committed by GitHub
parent 7437928f47
commit 3afb83ac3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 336 additions and 46 deletions

View file

@ -158,6 +158,7 @@ ORT_RUNTIME_CLASS(SessionOptions);
ORT_RUNTIME_CLASS(CustomOpDomain);
ORT_RUNTIME_CLASS(MapTypeInfo);
ORT_RUNTIME_CLASS(SequenceTypeInfo);
ORT_RUNTIME_CLASS(ModelMetadata);
// When passing in an allocator to any ORT function, be sure that the allocator object
// is not destroyed until the last allocated object using it is freed.
@ -381,7 +382,7 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* SessionGetOverridableInitializerTypeInfo)(_In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION;
/**
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it.
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
*/
OrtStatus*(ORT_API_CALL* SessionGetInputName)(_In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
@ -710,6 +711,42 @@ struct OrtApi {
ORT_CLASS_RELEASE(MapTypeInfo);
ORT_CLASS_RELEASE(SequenceTypeInfo);
/**
* \param out is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
* Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling()
* on the SessionOptions instance used to create the session.
*/
OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
_Outptr_ char** out)NO_EXCEPTION;
/**
* \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use.
*/
OrtStatus*(ORT_API_CALL* SessionGetModelMetadata)(_In_ const OrtSession* sess,
_Outptr_ OrtModelMetadata** out)NO_EXCEPTION;
/**
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
*/
OrtStatus*(ORT_API_CALL* ModelMetadataGetProducerName)(_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelMetadataGetGraphName)(_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelMetadataGetDomain)(_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelMetadataGetDescription)(_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
/**
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
* 'value' will be a nullptr if the given key is not found in the custom metadata map.
*/
OrtStatus*(ORT_API_CALL* ModelMetadataLookupCustomMetadataMap)(_In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator,
_In_ const char* key, _Outptr_ char** value)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION;
ORT_CLASS_RELEASE(ModelMetadata);
};
/*

View file

@ -53,7 +53,6 @@ template <typename T>
const OrtApi& Global<T>::api_ = *OrtGetApiBase()->GetApi(ORT_API_VERSION);
#endif
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
inline const OrtApi& GetApi() { return Global<void>::api_; }
@ -71,6 +70,7 @@ ORT_DEFINE_RELEASE(SessionOptions);
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
ORT_DEFINE_RELEASE(TypeInfo);
ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ModelMetadata);
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
template <typename T>
@ -82,7 +82,7 @@ struct Base {
~Base() { OrtRelease(p_); }
operator T*() { return p_; }
operator const T*() const { return p_; }
operator const T *() const { return p_; }
T* release() {
T* p = p_;
@ -118,6 +118,7 @@ struct MemoryInfo;
struct Env;
struct TypeInfo;
struct Value;
struct ModelMetadata;
struct Env : Base<OrtEnv> {
Env(std::nullptr_t) {}
@ -186,6 +187,18 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
};
struct ModelMetadata : Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {}
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
char* GetProducerName(OrtAllocator* allocator) const;
char* GetGraphName(OrtAllocator* allocator) const;
char* GetDomain(OrtAllocator* allocator) const;
char* GetDescription(OrtAllocator* allocator) const;
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
int64_t GetVersion() const;
};
struct Session : Base<OrtSession> {
explicit Session(std::nullptr_t) {}
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
@ -205,6 +218,8 @@ struct Session : Base<OrtSession> {
char* GetInputName(size_t index, OrtAllocator* allocator) const;
char* GetOutputName(size_t index, OrtAllocator* allocator) const;
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
char* EndProfiling(OrtAllocator* allocator) const;
ModelMetadata GetModelMetadata() const;
TypeInfo GetInputTypeInfo(size_t index) const;
TypeInfo GetOutputTypeInfo(size_t index) const;
@ -274,7 +289,7 @@ struct AllocatorWithDefaultOptions {
AllocatorWithDefaultOptions();
operator OrtAllocator*() { return p_; }
operator const OrtAllocator*() const { return p_; }
operator const OrtAllocator *() const { return p_; }
void* Alloc(size_t size);
void Free(void* p);

View file

@ -278,6 +278,54 @@ inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator*
return out;
}
inline char* Session::EndProfiling(OrtAllocator* allocator) const {
char* out;
ThrowOnError(Global<void>::api_.SessionEndProfiling(p_, allocator, &out));
return out;
}
inline ModelMetadata Session::GetModelMetadata() const {
OrtModelMetadata* out;
ThrowOnError(Global<void>::api_.SessionGetModelMetadata(p_, &out));
return ModelMetadata{out};
}
inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const {
char* out;
ThrowOnError(Global<void>::api_.ModelMetadataGetProducerName(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const {
char* out;
ThrowOnError(Global<void>::api_.ModelMetadataGetGraphName(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const {
char* out;
ThrowOnError(Global<void>::api_.ModelMetadataGetDomain(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
char* out;
ThrowOnError(Global<void>::api_.ModelMetadataGetDescription(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
char* out;
ThrowOnError(Global<void>::api_.ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
return out;
}
inline int64_t ModelMetadata::GetVersion() const {
int64_t out;
ThrowOnError(Global<void>::api_.ModelMetadataGetVersion(p_, &out));
return out;
}
inline TypeInfo Session::GetInputTypeInfo(size_t index) const {
OrtTypeInfo* out;
ThrowOnError(Global<void>::api_.SessionGetInputTypeInfo(p_, index, &out));

View file

@ -56,6 +56,13 @@ class LoggingManager;
* Pre-defined and custom metadata about the model.
*/
struct ModelMetadata {
ModelMetadata() = default;
ModelMetadata(const ModelMetadata& other)
: producer_name(other.producer_name), graph_name(other.graph_name), domain(other.domain), description(other.description), version(other.version), custom_metadata_map(other.custom_metadata_map) {
}
~ModelMetadata() = default;
ModelMetadata& operator=(const ModelMetadata&) = delete;
std::string producer_name;
std::string graph_name;
std::string domain;

View file

@ -49,9 +49,6 @@ using namespace onnxruntime;
if (_status) return _status; \
} while (0)
#define TENSOR_READ_API_BEGIN \
API_IMPL_BEGIN \
auto v = reinterpret_cast<const ::OrtValue*>(value); \
@ -664,6 +661,99 @@ static OrtStatus* GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index,
return nullptr;
}
ORT_API_STATUS_IMPL(OrtApis::SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
_Out_ char** out) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
auto profile_file_name = session->EndProfiling();
*out = StrDup(profile_file_name, allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionGetModelMetadata, _In_ const OrtSession* sess,
_Outptr_ OrtModelMetadata** out) {
API_IMPL_BEGIN
auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
auto p = session->GetModelMetadata();
if (!p.first.IsOK())
return ToOrtStatus(p.first);
*out = reinterpret_cast<OrtModelMetadata*>(new ModelMetadata(*p.second));
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetProducerName,
_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
API_IMPL_BEGIN
auto producer_name = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->producer_name;
*value = StrDup(producer_name, allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphName,
_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
API_IMPL_BEGIN
auto graph_name = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->graph_name;
*value = StrDup(graph_name, allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDomain,
_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
API_IMPL_BEGIN
auto domain = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->domain;
*value = StrDup(domain, allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription,
_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
API_IMPL_BEGIN
auto description = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->description;
*value = StrDup(description, allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap,
_In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator,
_In_ const char* key, _Outptr_ char** value) {
API_IMPL_BEGIN
auto custom_metadata_map =
reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->custom_metadata_map;
std::string temp(key);
auto iter = custom_metadata_map.find(temp);
if (iter == custom_metadata_map.end()) {
*value = nullptr;
} else {
*value = StrDup(iter->second, allocator);
}
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetVersion,
_In_ const OrtModelMetadata* model_metadata,
_Out_ int64_t* value) {
API_IMPL_BEGIN
*value = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->version;
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionGetInputName, _In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Outptr_ char** output) {
API_IMPL_BEGIN
@ -1400,7 +1490,16 @@ static constexpr OrtApi ort_api_1_to_2 = {
&OrtApis::GetMapValueType,
&OrtApis::GetSequenceElementType,
&OrtApis::ReleaseMapTypeInfo,
&OrtApis::ReleaseSequenceTypeInfo
&OrtApis::ReleaseSequenceTypeInfo,
&OrtApis::SessionEndProfiling,
&OrtApis::SessionGetModelMetadata,
&OrtApis::ModelMetadataGetProducerName,
&OrtApis::ModelMetadataGetGraphName,
&OrtApis::ModelMetadataGetDomain,
&OrtApis::ModelMetadataGetDescription,
&OrtApis::ModelMetadataLookupCustomMetadataMap,
&OrtApis::ModelMetadataGetVersion,
&OrtApis::ReleaseModelMetadata,
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
@ -1428,4 +1527,5 @@ ORT_API(void, OrtApis::ReleaseEnv, _Frees_ptr_opt_ OrtEnv* value) {
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata)

View file

@ -18,6 +18,7 @@ ORT_API(void, ReleaseSessionOptions, OrtSessionOptions*);
ORT_API(void, ReleaseCustomOpDomain, OrtCustomOpDomain*);
ORT_API(void, ReleaseMapTypeInfo, OrtMapTypeInfo*);
ORT_API(void, ReleaseSequenceTypeInfo, OrtSequenceTypeInfo*);
ORT_API(void, ReleaseModelMetadata, OrtModelMetadata*);
ORT_API_STATUS_IMPL(CreateStatus, OrtErrorCode code, _In_ const char* msg);
OrtErrorCode ORT_API_CALL GetErrorCode(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
@ -73,6 +74,25 @@ ORT_API_STATUS_IMPL(SessionGetInputName, _In_ const OrtSession* sess, size_t ind
ORT_API_STATUS_IMPL(SessionGetOutputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(SessionGetOverridableInitializerName, _In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
_Outptr_ char** out);
ORT_API_STATUS_IMPL(SessionGetModelMetadata, _In_ const OrtSession* sess,
_Outptr_ OrtModelMetadata** out);
ORT_API_STATUS_IMPL(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator,
_In_ const char* key, _Outptr_ char** value);
ORT_API_STATUS_IMPL(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata,
_Out_ int64_t* value);
ORT_API_STATUS_IMPL(CreateRunOptions, _Outptr_ OrtRunOptions** out);

View file

@ -57,14 +57,13 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
}
}
template <typename T, typename OutT>
void TestInference(Ort::Env& env, T model_uri,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& expected_dims_y,
const std::vector<OutT>& expected_values_y,
int provider_type,
int provider_type,
OrtCustomOpDomain* custom_op_domain_ptr,
const char* custom_op_library_filename) {
Ort::SessionOptions session_options;
@ -97,8 +96,8 @@ void TestInference(Ort::Env& env, T model_uri,
session_options.Add(custom_op_domain_ptr);
}
if (custom_op_library_filename){
void* library_handle = nullptr; // leak this, no harm.
if (custom_op_library_filename) {
void* library_handle = nullptr; // leak this, no harm.
Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &library_handle);
}
@ -107,24 +106,24 @@ void TestInference(Ort::Env& env, T model_uri,
// Now run
//without preallocated output tensor
RunSession<OutT>(default_allocator.get(),
session,
inputs,
output_name,
expected_dims_y,
expected_values_y,
nullptr);
session,
inputs,
output_name,
expected_dims_y,
expected_values_y,
nullptr);
//with preallocated output tensor
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size());
//test it twice
for (int i = 0; i != 2; ++i)
RunSession<OutT>(default_allocator.get(),
session,
inputs,
output_name,
expected_dims_y,
expected_values_y,
&value_y);
session,
inputs,
output_name,
expected_dims_y,
expected_values_y,
&value_y);
}
static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
@ -132,6 +131,7 @@ static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx");
static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx");
static constexpr PATH_TYPE OVERRIDABLE_INITIALIZER_MODEL_URI = TSTR("testdata/overridable_initializer.onnx");
static constexpr PATH_TYPE NAMED_AND_ANON_DIM_PARAM_URI = TSTR("testdata/capi_symbolic_dims.onnx");
static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/model_with_valid_ort_config_json.onnx");
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
@ -267,36 +267,34 @@ TEST(CApiTest, DISABLED_test_custom_op_library) {
std::vector<Input> inputs(2);
inputs[0].name = "input_1";
inputs[0].dims = {3, 5};
inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
inputs[1].name = "input_2";
inputs[1].dims = {3, 5};
inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_y = {3, 5};
std::vector<int32_t> expected_values_y =
{17, 17, 17, 17, 17,
17, 18, 18, 18, 17,
17, 17, 17, 17, 17};
std::vector<int32_t> expected_values_y =
{17, 17, 17, 17, 17,
17, 18, 18, 18, 17,
17, 17, 17, 17, 17};
std::string lib_name;
#if defined(_WIN32)
lib_name = "custom_op_library.dll";
#elif defined(__APPLE__)
lib_name = "libcustom_op_library.dylib";
#else
lib_name = "libcustom_op_library.so";
#endif
#if defined(_WIN32)
lib_name = "custom_op_library.dll";
#elif defined(__APPLE__)
lib_name = "libcustom_op_library.dylib";
#else
lib_name = "libcustom_op_library.so";
#endif
TestInference<PATH_TYPE, int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str());
TestInference<PATH_TYPE, int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str());
}
#if defined(ENABLE_LANGUAGE_INTEROP_OPS) && !defined(_WIN32) // on windows, PYTHONHOME must be set explicitly
TEST(CApiTest, DISABLED_test_pyop) {
std::cout << "Test model with pyop" << std::endl;
@ -422,4 +420,69 @@ TEST(CApiTest, override_initializer) {
ASSERT_EQ(type_info.GetElementCount(), 1U);
float* output_data = ort_outputs[2].GetTensorMutableData<float>();
ASSERT_EQ(*output_data, f11_input_data[0]);
}
}
TEST(CApiTest, end_profiling) {
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
// Create session with profiling enabled (profiling is automatically turned on)
Ort::SessionOptions session_options_1;
#ifdef _WIN32
session_options_1.EnableProfiling(L"profile_prefix");
#else
session_options_1.EnableProfiling("profile_prefix");
#endif
Ort::Session session_1(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_1);
char* profile_file = session_1.EndProfiling(allocator.get());
ASSERT_TRUE(std::string(profile_file).find("profile_prefix") != std::string::npos);
// Create session with profiling disabled
Ort::SessionOptions session_options_2;
#ifdef _WIN32
session_options_2.DisableProfiling();
#else
session_options_2.DisableProfiling();
#endif
Ort::Session session_2(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_2);
profile_file = session_2.EndProfiling(allocator.get());
ASSERT_TRUE(std::string(profile_file) == std::string());
}
TEST(CApiTest, model_metadata) {
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
// Create session
Ort::SessionOptions session_options;
Ort::Session session(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options);
// Fetch model metadata
// The following all tap into the c++ APIs which internally wrap over C APIs
auto model_metadata = session.GetModelMetadata();
char* producer_name = model_metadata.GetProducerName(allocator.get());
ASSERT_TRUE(strcmp("Hari", producer_name) == 0);
char* graph_name = model_metadata.GetGraphName(allocator.get());
ASSERT_TRUE(strcmp("matmul test", graph_name) == 0);
char* domain = model_metadata.GetDomain(allocator.get());
ASSERT_TRUE(strcmp("", domain) == 0);
char* description = model_metadata.GetDescription(allocator.get());
ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0);
int64_t version = model_metadata.GetVersion();
ASSERT_TRUE(version == 1);
char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
ASSERT_TRUE(strcmp(lookup_value,
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
// key doesn't exist in custom metadata map
lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get());
ASSERT_TRUE(lookup_value == nullptr);
}