mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Support a few new C/C++ APIs (#2794)
* Initial commit * More changes * More changes * Changes * More changes * More changes * More changes * More changes * Updates * Fix break * PR feedback * Nit * Resolve conflicts * More changes
This commit is contained in:
parent
7437928f47
commit
3afb83ac3c
8 changed files with 336 additions and 46 deletions
|
|
@ -158,6 +158,7 @@ ORT_RUNTIME_CLASS(SessionOptions);
|
|||
ORT_RUNTIME_CLASS(CustomOpDomain);
|
||||
ORT_RUNTIME_CLASS(MapTypeInfo);
|
||||
ORT_RUNTIME_CLASS(SequenceTypeInfo);
|
||||
ORT_RUNTIME_CLASS(ModelMetadata);
|
||||
|
||||
// When passing in an allocator to any ORT function, be sure that the allocator object
|
||||
// is not destroyed until the last allocated object using it is freed.
|
||||
|
|
@ -381,7 +382,7 @@ struct OrtApi {
|
|||
OrtStatus*(ORT_API_CALL* SessionGetOverridableInitializerTypeInfo)(_In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it.
|
||||
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* SessionGetInputName)(_In_ const OrtSession* sess, size_t index,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
|
||||
|
|
@ -710,6 +711,42 @@ struct OrtApi {
|
|||
|
||||
ORT_CLASS_RELEASE(MapTypeInfo);
|
||||
ORT_CLASS_RELEASE(SequenceTypeInfo);
|
||||
|
||||
/**
|
||||
* \param out is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
|
||||
* Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling()
|
||||
* on the SessionOptions instance used to create the session.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ char** out)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* SessionGetModelMetadata)(_In_ const OrtSession* sess,
|
||||
_Outptr_ OrtModelMetadata** out)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataGetProducerName)(_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataGetGraphName)(_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataGetDomain)(_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataGetDescription)(_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION;
|
||||
/**
|
||||
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it.
|
||||
* 'value' will be a nullptr if the given key is not found in the custom metadata map.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataLookupCustomMetadataMap)(_In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator,
|
||||
_In_ const char* key, _Outptr_ char** value)NO_EXCEPTION;
|
||||
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION;
|
||||
|
||||
ORT_CLASS_RELEASE(ModelMetadata);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ template <typename T>
|
|||
const OrtApi& Global<T>::api_ = *OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
#endif
|
||||
|
||||
|
||||
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
|
||||
inline const OrtApi& GetApi() { return Global<void>::api_; }
|
||||
|
||||
|
|
@ -71,6 +70,7 @@ ORT_DEFINE_RELEASE(SessionOptions);
|
|||
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
|
||||
ORT_DEFINE_RELEASE(TypeInfo);
|
||||
ORT_DEFINE_RELEASE(Value);
|
||||
ORT_DEFINE_RELEASE(ModelMetadata);
|
||||
|
||||
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
|
||||
template <typename T>
|
||||
|
|
@ -82,7 +82,7 @@ struct Base {
|
|||
~Base() { OrtRelease(p_); }
|
||||
|
||||
operator T*() { return p_; }
|
||||
operator const T*() const { return p_; }
|
||||
operator const T *() const { return p_; }
|
||||
|
||||
T* release() {
|
||||
T* p = p_;
|
||||
|
|
@ -118,6 +118,7 @@ struct MemoryInfo;
|
|||
struct Env;
|
||||
struct TypeInfo;
|
||||
struct Value;
|
||||
struct ModelMetadata;
|
||||
|
||||
struct Env : Base<OrtEnv> {
|
||||
Env(std::nullptr_t) {}
|
||||
|
|
@ -186,6 +187,18 @@ struct SessionOptions : Base<OrtSessionOptions> {
|
|||
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
|
||||
};
|
||||
|
||||
struct ModelMetadata : Base<OrtModelMetadata> {
|
||||
explicit ModelMetadata(std::nullptr_t) {}
|
||||
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
|
||||
|
||||
char* GetProducerName(OrtAllocator* allocator) const;
|
||||
char* GetGraphName(OrtAllocator* allocator) const;
|
||||
char* GetDomain(OrtAllocator* allocator) const;
|
||||
char* GetDescription(OrtAllocator* allocator) const;
|
||||
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
|
||||
int64_t GetVersion() const;
|
||||
};
|
||||
|
||||
struct Session : Base<OrtSession> {
|
||||
explicit Session(std::nullptr_t) {}
|
||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
|
||||
|
|
@ -205,6 +218,8 @@ struct Session : Base<OrtSession> {
|
|||
char* GetInputName(size_t index, OrtAllocator* allocator) const;
|
||||
char* GetOutputName(size_t index, OrtAllocator* allocator) const;
|
||||
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
|
||||
char* EndProfiling(OrtAllocator* allocator) const;
|
||||
ModelMetadata GetModelMetadata() const;
|
||||
|
||||
TypeInfo GetInputTypeInfo(size_t index) const;
|
||||
TypeInfo GetOutputTypeInfo(size_t index) const;
|
||||
|
|
@ -274,7 +289,7 @@ struct AllocatorWithDefaultOptions {
|
|||
AllocatorWithDefaultOptions();
|
||||
|
||||
operator OrtAllocator*() { return p_; }
|
||||
operator const OrtAllocator*() const { return p_; }
|
||||
operator const OrtAllocator *() const { return p_; }
|
||||
|
||||
void* Alloc(size_t size);
|
||||
void Free(void* p);
|
||||
|
|
|
|||
|
|
@ -278,6 +278,54 @@ inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator*
|
|||
return out;
|
||||
}
|
||||
|
||||
inline char* Session::EndProfiling(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.SessionEndProfiling(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline ModelMetadata Session::GetModelMetadata() const {
|
||||
OrtModelMetadata* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetModelMetadata(p_, &out));
|
||||
return ModelMetadata{out};
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetProducerName(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetGraphName(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetDomain(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetDescription(p_, allocator, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
|
||||
char* out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline int64_t ModelMetadata::GetVersion() const {
|
||||
int64_t out;
|
||||
ThrowOnError(Global<void>::api_.ModelMetadataGetVersion(p_, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
inline TypeInfo Session::GetInputTypeInfo(size_t index) const {
|
||||
OrtTypeInfo* out;
|
||||
ThrowOnError(Global<void>::api_.SessionGetInputTypeInfo(p_, index, &out));
|
||||
|
|
|
|||
|
|
@ -56,6 +56,13 @@ class LoggingManager;
|
|||
* Pre-defined and custom metadata about the model.
|
||||
*/
|
||||
struct ModelMetadata {
|
||||
ModelMetadata() = default;
|
||||
ModelMetadata(const ModelMetadata& other)
|
||||
: producer_name(other.producer_name), graph_name(other.graph_name), domain(other.domain), description(other.description), version(other.version), custom_metadata_map(other.custom_metadata_map) {
|
||||
}
|
||||
~ModelMetadata() = default;
|
||||
ModelMetadata& operator=(const ModelMetadata&) = delete;
|
||||
|
||||
std::string producer_name;
|
||||
std::string graph_name;
|
||||
std::string domain;
|
||||
|
|
|
|||
|
|
@ -49,9 +49,6 @@ using namespace onnxruntime;
|
|||
if (_status) return _status; \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
|
||||
#define TENSOR_READ_API_BEGIN \
|
||||
API_IMPL_BEGIN \
|
||||
auto v = reinterpret_cast<const ::OrtValue*>(value); \
|
||||
|
|
@ -664,6 +661,99 @@ static OrtStatus* GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
|
||||
_Out_ char** out) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
|
||||
auto profile_file_name = session->EndProfiling();
|
||||
*out = StrDup(profile_file_name, allocator);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionGetModelMetadata, _In_ const OrtSession* sess,
|
||||
_Outptr_ OrtModelMetadata** out) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
|
||||
auto p = session->GetModelMetadata();
|
||||
if (!p.first.IsOK())
|
||||
return ToOrtStatus(p.first);
|
||||
*out = reinterpret_cast<OrtModelMetadata*>(new ModelMetadata(*p.second));
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetProducerName,
|
||||
_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
|
||||
API_IMPL_BEGIN
|
||||
auto producer_name = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->producer_name;
|
||||
*value = StrDup(producer_name, allocator);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphName,
|
||||
_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
|
||||
API_IMPL_BEGIN
|
||||
auto graph_name = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->graph_name;
|
||||
*value = StrDup(graph_name, allocator);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDomain,
|
||||
_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
|
||||
API_IMPL_BEGIN
|
||||
auto domain = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->domain;
|
||||
*value = StrDup(domain, allocator);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription,
|
||||
_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
|
||||
API_IMPL_BEGIN
|
||||
auto description = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->description;
|
||||
*value = StrDup(description, allocator);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap,
|
||||
_In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator,
|
||||
_In_ const char* key, _Outptr_ char** value) {
|
||||
API_IMPL_BEGIN
|
||||
auto custom_metadata_map =
|
||||
reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->custom_metadata_map;
|
||||
|
||||
std::string temp(key);
|
||||
|
||||
auto iter = custom_metadata_map.find(temp);
|
||||
|
||||
if (iter == custom_metadata_map.end()) {
|
||||
*value = nullptr;
|
||||
} else {
|
||||
*value = StrDup(iter->second, allocator);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetVersion,
|
||||
_In_ const OrtModelMetadata* model_metadata,
|
||||
_Out_ int64_t* value) {
|
||||
API_IMPL_BEGIN
|
||||
*value = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->version;
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionGetInputName, _In_ const OrtSession* sess, size_t index,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** output) {
|
||||
API_IMPL_BEGIN
|
||||
|
|
@ -1400,7 +1490,16 @@ static constexpr OrtApi ort_api_1_to_2 = {
|
|||
&OrtApis::GetMapValueType,
|
||||
&OrtApis::GetSequenceElementType,
|
||||
&OrtApis::ReleaseMapTypeInfo,
|
||||
&OrtApis::ReleaseSequenceTypeInfo
|
||||
&OrtApis::ReleaseSequenceTypeInfo,
|
||||
&OrtApis::SessionEndProfiling,
|
||||
&OrtApis::SessionGetModelMetadata,
|
||||
&OrtApis::ModelMetadataGetProducerName,
|
||||
&OrtApis::ModelMetadataGetGraphName,
|
||||
&OrtApis::ModelMetadataGetDomain,
|
||||
&OrtApis::ModelMetadataGetDescription,
|
||||
&OrtApis::ModelMetadataLookupCustomMetadataMap,
|
||||
&OrtApis::ModelMetadataGetVersion,
|
||||
&OrtApis::ReleaseModelMetadata,
|
||||
};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
|
|
@ -1428,4 +1527,5 @@ ORT_API(void, OrtApis::ReleaseEnv, _Frees_ptr_opt_ OrtEnv* value) {
|
|||
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ ORT_API(void, ReleaseSessionOptions, OrtSessionOptions*);
|
|||
ORT_API(void, ReleaseCustomOpDomain, OrtCustomOpDomain*);
|
||||
ORT_API(void, ReleaseMapTypeInfo, OrtMapTypeInfo*);
|
||||
ORT_API(void, ReleaseSequenceTypeInfo, OrtSequenceTypeInfo*);
|
||||
ORT_API(void, ReleaseModelMetadata, OrtModelMetadata*);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateStatus, OrtErrorCode code, _In_ const char* msg);
|
||||
OrtErrorCode ORT_API_CALL GetErrorCode(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
|
||||
|
|
@ -73,6 +74,25 @@ ORT_API_STATUS_IMPL(SessionGetInputName, _In_ const OrtSession* sess, size_t ind
|
|||
ORT_API_STATUS_IMPL(SessionGetOutputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value);
|
||||
ORT_API_STATUS_IMPL(SessionGetOverridableInitializerName, _In_ const OrtSession* sess, size_t index,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
|
||||
ORT_API_STATUS_IMPL(SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ char** out);
|
||||
ORT_API_STATUS_IMPL(SessionGetModelMetadata, _In_ const OrtSession* sess,
|
||||
_Outptr_ OrtModelMetadata** out);
|
||||
|
||||
ORT_API_STATUS_IMPL(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
|
||||
ORT_API_STATUS_IMPL(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
|
||||
ORT_API_STATUS_IMPL(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
|
||||
ORT_API_STATUS_IMPL(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
|
||||
ORT_API_STATUS_IMPL(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator,
|
||||
_In_ const char* key, _Outptr_ char** value);
|
||||
|
||||
ORT_API_STATUS_IMPL(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Out_ int64_t* value);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateRunOptions, _Outptr_ OrtRunOptions** out);
|
||||
|
||||
|
|
|
|||
|
|
@ -57,14 +57,13 @@ void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void TestInference(Ort::Env& env, T model_uri,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& expected_dims_y,
|
||||
const std::vector<OutT>& expected_values_y,
|
||||
int provider_type,
|
||||
int provider_type,
|
||||
OrtCustomOpDomain* custom_op_domain_ptr,
|
||||
const char* custom_op_library_filename) {
|
||||
Ort::SessionOptions session_options;
|
||||
|
|
@ -97,8 +96,8 @@ void TestInference(Ort::Env& env, T model_uri,
|
|||
session_options.Add(custom_op_domain_ptr);
|
||||
}
|
||||
|
||||
if (custom_op_library_filename){
|
||||
void* library_handle = nullptr; // leak this, no harm.
|
||||
if (custom_op_library_filename) {
|
||||
void* library_handle = nullptr; // leak this, no harm.
|
||||
Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &library_handle);
|
||||
}
|
||||
|
||||
|
|
@ -107,24 +106,24 @@ void TestInference(Ort::Env& env, T model_uri,
|
|||
// Now run
|
||||
//without preallocated output tensor
|
||||
RunSession<OutT>(default_allocator.get(),
|
||||
session,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y,
|
||||
nullptr);
|
||||
session,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y,
|
||||
nullptr);
|
||||
//with preallocated output tensor
|
||||
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size());
|
||||
|
||||
//test it twice
|
||||
for (int i = 0; i != 2; ++i)
|
||||
RunSession<OutT>(default_allocator.get(),
|
||||
session,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y,
|
||||
&value_y);
|
||||
session,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y,
|
||||
&value_y);
|
||||
}
|
||||
|
||||
static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
|
||||
|
|
@ -132,6 +131,7 @@ static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx");
|
|||
static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx");
|
||||
static constexpr PATH_TYPE OVERRIDABLE_INITIALIZER_MODEL_URI = TSTR("testdata/overridable_initializer.onnx");
|
||||
static constexpr PATH_TYPE NAMED_AND_ANON_DIM_PARAM_URI = TSTR("testdata/capi_symbolic_dims.onnx");
|
||||
static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/model_with_valid_ort_config_json.onnx");
|
||||
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
|
||||
|
|
@ -267,36 +267,34 @@ TEST(CApiTest, DISABLED_test_custom_op_library) {
|
|||
std::vector<Input> inputs(2);
|
||||
inputs[0].name = "input_1";
|
||||
inputs[0].dims = {3, 5};
|
||||
inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
|
||||
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
|
||||
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
|
||||
inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
|
||||
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
|
||||
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
|
||||
inputs[1].name = "input_2";
|
||||
inputs[1].dims = {3, 5};
|
||||
inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
|
||||
10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
|
||||
5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
|
||||
inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
|
||||
10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
|
||||
5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_y = {3, 5};
|
||||
std::vector<int32_t> expected_values_y =
|
||||
{17, 17, 17, 17, 17,
|
||||
17, 18, 18, 18, 17,
|
||||
17, 17, 17, 17, 17};
|
||||
std::vector<int32_t> expected_values_y =
|
||||
{17, 17, 17, 17, 17,
|
||||
17, 18, 18, 18, 17,
|
||||
17, 17, 17, 17, 17};
|
||||
|
||||
std::string lib_name;
|
||||
#if defined(_WIN32)
|
||||
lib_name = "custom_op_library.dll";
|
||||
#elif defined(__APPLE__)
|
||||
lib_name = "libcustom_op_library.dylib";
|
||||
#else
|
||||
lib_name = "libcustom_op_library.so";
|
||||
#endif
|
||||
#if defined(_WIN32)
|
||||
lib_name = "custom_op_library.dll";
|
||||
#elif defined(__APPLE__)
|
||||
lib_name = "libcustom_op_library.dylib";
|
||||
#else
|
||||
lib_name = "libcustom_op_library.so";
|
||||
#endif
|
||||
|
||||
TestInference<PATH_TYPE, int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str());
|
||||
TestInference<PATH_TYPE, int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str());
|
||||
}
|
||||
|
||||
|
||||
|
||||
#if defined(ENABLE_LANGUAGE_INTEROP_OPS) && !defined(_WIN32) // on windows, PYTHONHOME must be set explicitly
|
||||
TEST(CApiTest, DISABLED_test_pyop) {
|
||||
std::cout << "Test model with pyop" << std::endl;
|
||||
|
|
@ -422,4 +420,69 @@ TEST(CApiTest, override_initializer) {
|
|||
ASSERT_EQ(type_info.GetElementCount(), 1U);
|
||||
float* output_data = ort_outputs[2].GetTensorMutableData<float>();
|
||||
ASSERT_EQ(*output_data, f11_input_data[0]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CApiTest, end_profiling) {
|
||||
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
|
||||
// Create session with profiling enabled (profiling is automatically turned on)
|
||||
Ort::SessionOptions session_options_1;
|
||||
#ifdef _WIN32
|
||||
session_options_1.EnableProfiling(L"profile_prefix");
|
||||
#else
|
||||
session_options_1.EnableProfiling("profile_prefix");
|
||||
#endif
|
||||
Ort::Session session_1(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_1);
|
||||
char* profile_file = session_1.EndProfiling(allocator.get());
|
||||
|
||||
ASSERT_TRUE(std::string(profile_file).find("profile_prefix") != std::string::npos);
|
||||
|
||||
// Create session with profiling disabled
|
||||
Ort::SessionOptions session_options_2;
|
||||
#ifdef _WIN32
|
||||
session_options_2.DisableProfiling();
|
||||
#else
|
||||
session_options_2.DisableProfiling();
|
||||
#endif
|
||||
Ort::Session session_2(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options_2);
|
||||
profile_file = session_2.EndProfiling(allocator.get());
|
||||
|
||||
ASSERT_TRUE(std::string(profile_file) == std::string());
|
||||
}
|
||||
|
||||
TEST(CApiTest, model_metadata) {
|
||||
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
|
||||
// Create session
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::Session session(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options);
|
||||
|
||||
// Fetch model metadata
|
||||
// The following all tap into the c++ APIs which internally wrap over C APIs
|
||||
auto model_metadata = session.GetModelMetadata();
|
||||
|
||||
char* producer_name = model_metadata.GetProducerName(allocator.get());
|
||||
ASSERT_TRUE(strcmp("Hari", producer_name) == 0);
|
||||
|
||||
char* graph_name = model_metadata.GetGraphName(allocator.get());
|
||||
ASSERT_TRUE(strcmp("matmul test", graph_name) == 0);
|
||||
|
||||
char* domain = model_metadata.GetDomain(allocator.get());
|
||||
ASSERT_TRUE(strcmp("", domain) == 0);
|
||||
|
||||
char* description = model_metadata.GetDescription(allocator.get());
|
||||
ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0);
|
||||
|
||||
int64_t version = model_metadata.GetVersion();
|
||||
ASSERT_TRUE(version == 1);
|
||||
|
||||
char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get());
|
||||
ASSERT_TRUE(strcmp(lookup_value,
|
||||
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0);
|
||||
|
||||
// key doesn't exist in custom metadata map
|
||||
lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get());
|
||||
ASSERT_TRUE(lookup_value == nullptr);
|
||||
}
|
||||
|
|
|
|||
Binary file not shown.
Loading…
Reference in a new issue