mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Ryanunderhill/c api 8 (#297)
* Make OrtAllocator not be reference counted * Make the allocator interface more type safe * Fix build break * Build break fix * Build break fix * Mistake in previous build fix. * Fix review comments + build break * Missed the export symbols * C specific error, need 'struct' keyword in one case. * Function calling OrtReleaseObject instead of OrtReleaseEnv
This commit is contained in:
parent
751eb60819
commit
98a92547bf
14 changed files with 154 additions and 203 deletions
|
|
@ -132,6 +132,7 @@ typedef enum OrtErrorCode {
|
|||
ORT_API(void, OrtRelease##X, _Frees_ptr_opt_ Ort##X* input);
|
||||
|
||||
// The actual types defined have an Ort prefix
|
||||
ORT_RUNTIME_CLASS(Env);
|
||||
ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success
|
||||
ORT_RUNTIME_CLASS(Provider);
|
||||
ORT_RUNTIME_CLASS(AllocatorInfo);
|
||||
|
|
@ -147,8 +148,6 @@ struct OrtRunOptions;
|
|||
typedef struct OrtRunOptions OrtRunOptions;
|
||||
struct OrtSessionOptions;
|
||||
typedef struct OrtSessionOptions OrtSessionOptions;
|
||||
struct OrtEnv;
|
||||
typedef struct OrtEnv OrtEnv;
|
||||
|
||||
/**
|
||||
* Every type inherented from OrtObject should be deleted by OrtReleaseObject(...).
|
||||
|
|
@ -161,17 +160,15 @@ typedef struct OrtObject {
|
|||
|
||||
} OrtObject;
|
||||
|
||||
//inherented from OrtObject
|
||||
typedef struct OrtAllocatorInterface {
|
||||
struct OrtObject parent;
|
||||
void*(ORT_API_CALL* Alloc)(void* this_, size_t size);
|
||||
void(ORT_API_CALL* Free)(void* this_, void* p);
|
||||
const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const void* this_);
|
||||
} OrtAllocatorInterface;
|
||||
// When passing in an allocator to any ORT function, be sure that the allocator object
|
||||
// is not destroyed until the last allocated object using it is freed.
|
||||
typedef struct OrtAllocator {
|
||||
void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size);
|
||||
void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p);
|
||||
const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_);
|
||||
} OrtAllocator;
|
||||
|
||||
typedef OrtAllocatorInterface* OrtAllocator;
|
||||
|
||||
//Inherented from OrtObject
|
||||
// Inherented from OrtObject
|
||||
typedef struct OrtProviderFactoryInterface {
|
||||
OrtObject parent;
|
||||
OrtStatus*(ORT_API_CALL* CreateProvider)(void* this_, OrtProvider** out);
|
||||
|
|
@ -183,7 +180,7 @@ typedef void(ORT_API_CALL* OrtLoggingFunction)(
|
|||
|
||||
/**
|
||||
* OrtEnv is process-wide. For each process, only one OrtEnv can be created.
|
||||
* \param out Should be freed by `OrtReleaseObject` after use
|
||||
* \param out Should be freed by `OrtReleaseEnv` after use
|
||||
*/
|
||||
ORT_API_STATUS(OrtInitialize, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Out_ OrtEnv** out)
|
||||
ORT_ALL_ARGS_NONNULL;
|
||||
|
|
@ -216,7 +213,7 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
|
|||
*/
|
||||
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions);
|
||||
|
||||
/// create a copy of an existing OrtSessionOptions
|
||||
// create a copy of an existing OrtSessionOptions
|
||||
ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions*);
|
||||
ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options);
|
||||
ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options);
|
||||
|
|
@ -238,13 +235,13 @@ ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options);
|
|||
ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options);
|
||||
ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options);
|
||||
|
||||
///< logger id to use for session output
|
||||
// < logger id to use for session output
|
||||
ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid);
|
||||
|
||||
///< applies to session load, initialization, etc
|
||||
// < applies to session load, initialization, etc
|
||||
ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level);
|
||||
|
||||
///How many threads in the session thread pool.
|
||||
// How many threads in the session thread pool.
|
||||
ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
|
||||
|
||||
/**
|
||||
|
|
@ -269,6 +266,9 @@ ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t i
|
|||
*/
|
||||
ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
|
||||
|
||||
/**
|
||||
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it.
|
||||
*/
|
||||
ORT_API_STATUS(OrtSessionGetInputName, _In_ const OrtSession* sess, size_t index,
|
||||
_Inout_ OrtAllocator* allocator, _Out_ char** value);
|
||||
ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t index,
|
||||
|
|
@ -291,8 +291,7 @@ ORT_API(void, OrtRunOptionsSetTerminate, _In_ OrtRunOptions*, _In_ int flag);
|
|||
|
||||
/**
|
||||
* Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value
|
||||
* \param out will keep a reference to the allocator, without reference counting(will be fixed). Should be freed by
|
||||
* calling OrtReleaseValue
|
||||
* \param out Should be freed by calling OrtReleaseValue
|
||||
* \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx
|
||||
*/
|
||||
ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
|
||||
|
|
@ -441,8 +440,8 @@ ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size);
|
|||
ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p);
|
||||
ORT_API(const OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr);
|
||||
|
||||
// Call OrtReleaseObject to release the returned value
|
||||
ORT_API_STATUS(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out);
|
||||
ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator);
|
||||
|
||||
/**
|
||||
* \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus
|
||||
|
|
@ -454,16 +453,11 @@ ORT_API(OrtErrorCode, OrtGetErrorCode, _In_ const OrtStatus* status)
|
|||
ORT_ALL_ARGS_NONNULL;
|
||||
/**
|
||||
* \param status must not be NULL
|
||||
* \return The error message inside the `status`. Don't free the returned value.
|
||||
* \return The error message inside the `status`. Do not free the returned value.
|
||||
*/
|
||||
ORT_API(const char*, OrtGetErrorMessage, _In_ const OrtStatus* status)
|
||||
ORT_ALL_ARGS_NONNULL;
|
||||
|
||||
/**
|
||||
* Deprecated. Please use OrtReleaseObject
|
||||
*/
|
||||
ORT_API(void, OrtReleaseEnv, OrtEnv* env);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -24,6 +24,22 @@
|
|||
return Ort##NAME(value.get()); \
|
||||
}
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct default_delete<OrtAllocator> {
|
||||
void operator()(OrtAllocator* ptr) {
|
||||
OrtReleaseAllocator(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct default_delete<OrtEnv> {
|
||||
void operator()(OrtEnv* ptr) {
|
||||
OrtReleaseEnv(ptr);
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
#define DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TYPE_NAME) \
|
||||
namespace std { \
|
||||
template <> \
|
||||
|
|
@ -34,9 +50,7 @@
|
|||
}; \
|
||||
}
|
||||
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(Env);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TypeInfo);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(Allocator);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TensorTypeAndShapeInfo);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(RunOptions);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(SessionOptions);
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ OrtGetValueType
|
|||
OrtInitialize
|
||||
OrtInitializeWithCustomLogger
|
||||
OrtIsTensor
|
||||
OrtReleaseAllocator
|
||||
OrtReleaseAllocatorInfo
|
||||
OrtReleaseEnv
|
||||
OrtReleaseObject
|
||||
|
|
|
|||
|
|
@ -8,20 +8,15 @@
|
|||
namespace onnxruntime {
|
||||
class AllocatorWrapper : public IAllocator {
|
||||
public:
|
||||
AllocatorWrapper(OrtAllocator* impl) : impl_(impl) {
|
||||
(*impl)->parent.AddRef(impl);
|
||||
}
|
||||
~AllocatorWrapper() {
|
||||
(*impl_)->parent.Release(impl_);
|
||||
}
|
||||
AllocatorWrapper(OrtAllocator* impl) : impl_(impl) {}
|
||||
void* Alloc(size_t size) override {
|
||||
return (*impl_)->Alloc(impl_, size);
|
||||
return impl_->Alloc(impl_, size);
|
||||
}
|
||||
void Free(void* p) override {
|
||||
return (*impl_)->Free(impl_, p);
|
||||
return impl_->Free(impl_, p);
|
||||
}
|
||||
const OrtAllocatorInfo& Info() const override {
|
||||
return *(OrtAllocatorInfo*)(*impl_)->Info(impl_);
|
||||
return *impl_->Info(impl_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -5,81 +5,55 @@
|
|||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include <assert.h>
|
||||
|
||||
#define ORT_ALLOCATOR_IMPL_BEGIN(CLASS_NAME) \
|
||||
class CLASS_NAME { \
|
||||
private: \
|
||||
const OrtAllocatorInterface* vtable_ = &table_; \
|
||||
std::atomic_int ref_count_; \
|
||||
static void* ORT_API_CALL Alloc_(void* this_ptr, size_t size) { \
|
||||
return ((CLASS_NAME*)this_ptr)->Alloc(size); \
|
||||
} \
|
||||
static void ORT_API_CALL Free_(void* this_ptr, void* p) { \
|
||||
return ((CLASS_NAME*)this_ptr)->Free(p); \
|
||||
} \
|
||||
static const OrtAllocatorInfo* ORT_API_CALL Info_(const void* this_ptr) { \
|
||||
return ((const CLASS_NAME*)this_ptr)->Info(); \
|
||||
} \
|
||||
static uint32_t ORT_API_CALL AddRef_(void* this_) { \
|
||||
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
|
||||
return ++this_ptr->ref_count_; \
|
||||
} \
|
||||
static uint32_t ORT_API_CALL Release_(void* this_) { \
|
||||
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
|
||||
uint32_t ret = --this_ptr->ref_count_; \
|
||||
if (ret == 0) \
|
||||
delete this_ptr; \
|
||||
return 0; \
|
||||
} \
|
||||
static OrtAllocatorInterface table_;
|
||||
// In the future we'll have more than one allocator type. Since all allocators are of type 'OrtAllocator' and there is a single
|
||||
// OrtReleaseAllocator function, we need to have a common base type that lets us delete them.
|
||||
struct OrtAllocatorImpl : OrtAllocator {
|
||||
virtual ~OrtAllocatorImpl() {}
|
||||
};
|
||||
|
||||
#define ORT_ALLOCATOR_IMPL_END \
|
||||
} \
|
||||
;
|
||||
struct OrtDefaultAllocator : OrtAllocatorImpl {
|
||||
OrtDefaultAllocator() {
|
||||
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<OrtDefaultAllocator*>(this_)->Alloc(size); };
|
||||
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<OrtDefaultAllocator*>(this_)->Free(p); };
|
||||
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const OrtDefaultAllocator*>(this_)->Info(); };
|
||||
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
|
||||
}
|
||||
|
||||
ORT_ALLOCATOR_IMPL_BEGIN(OrtDefaultAllocator)
|
||||
private:
|
||||
OrtAllocatorInfo* cpuAllocatorInfo;
|
||||
OrtDefaultAllocator() : ref_count_(1) {
|
||||
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
|
||||
}
|
||||
~OrtDefaultAllocator() {
|
||||
assert(ref_count_ == 0);
|
||||
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
|
||||
}
|
||||
~OrtDefaultAllocator() {
|
||||
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
|
||||
}
|
||||
|
||||
public:
|
||||
OrtDefaultAllocator(const OrtDefaultAllocator&) = delete;
|
||||
OrtDefaultAllocator& operator=(const OrtDefaultAllocator&) = delete;
|
||||
OrtAllocatorInterface** Upcast() {
|
||||
return const_cast<OrtAllocatorInterface**>(&vtable_);
|
||||
}
|
||||
static OrtAllocatorInterface** Create() {
|
||||
return (OrtAllocatorInterface**)new OrtDefaultAllocator();
|
||||
}
|
||||
void* Alloc(size_t size) {
|
||||
return ::malloc(size);
|
||||
}
|
||||
void Free(void* p) {
|
||||
return ::free(p);
|
||||
}
|
||||
const OrtAllocatorInfo* Info() const {
|
||||
return cpuAllocatorInfo;
|
||||
}
|
||||
ORT_ALLOCATOR_IMPL_END
|
||||
void* Alloc(size_t size) {
|
||||
return ::malloc(size);
|
||||
}
|
||||
void Free(void* p) {
|
||||
return ::free(p);
|
||||
}
|
||||
const OrtAllocatorInfo* Info() const {
|
||||
return cpuAllocatorInfo;
|
||||
}
|
||||
|
||||
private:
|
||||
OrtDefaultAllocator(const OrtDefaultAllocator&) = delete;
|
||||
OrtDefaultAllocator& operator=(const OrtDefaultAllocator&) = delete;
|
||||
|
||||
OrtAllocatorInfo* cpuAllocatorInfo;
|
||||
};
|
||||
|
||||
#define API_IMPL_BEGIN try {
|
||||
#define API_IMPL_END \
|
||||
} \
|
||||
catch (std::exception & ex) { \
|
||||
#define API_IMPL_END \
|
||||
} \
|
||||
catch (std::exception & ex) { \
|
||||
return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
|
||||
}
|
||||
|
||||
OrtAllocatorInterface OrtDefaultAllocator::table_ = {
|
||||
{OrtDefaultAllocator::AddRef_, OrtDefaultAllocator::Release_}, OrtDefaultAllocator::Alloc_, OrtDefaultAllocator::Free_, OrtDefaultAllocator::Info_};
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out) {
|
||||
API_IMPL_BEGIN
|
||||
*out = OrtDefaultAllocator::Create();
|
||||
*out = new OrtDefaultAllocator();
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator) {
|
||||
delete static_cast<OrtAllocatorImpl*>(allocator);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,20 +45,18 @@ using onnxruntime::common::Status;
|
|||
if (_status) return _status; \
|
||||
} while (0)
|
||||
|
||||
struct OrtEnv : public onnxruntime::ObjectBase<OrtEnv> {
|
||||
struct OrtEnv {
|
||||
public:
|
||||
Environment* value;
|
||||
LoggingManager* loggingManager;
|
||||
friend class onnxruntime::ObjectBase<OrtEnv>;
|
||||
|
||||
OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) {
|
||||
ORT_CHECK_C_OBJECT_LAYOUT;
|
||||
}
|
||||
/**
|
||||
* This function will call ::google::protobuf::ShutdownProtobufLibrary
|
||||
*/
|
||||
~OrtEnv() {
|
||||
assert(ref_count == 0);
|
||||
delete loggingManager;
|
||||
delete value;
|
||||
}
|
||||
|
|
@ -166,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtFillStringTensor, _In_ OrtValue* value, _In_ const char*
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocatorInterface** allocator,
|
||||
OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocator* allocator,
|
||||
std::unique_ptr<Tensor>* out) {
|
||||
size_t elem_count = 1;
|
||||
std::vector<int64_t> shapes(shape_len);
|
||||
|
|
@ -179,13 +177,13 @@ OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocatorI
|
|||
if (!IAllocator::CalcMemSizeForArray(sizeof(T), elem_count, &size_to_allocate)) {
|
||||
return OrtCreateStatus(ORT_FAIL, "not enough memory");
|
||||
}
|
||||
void* p_data = (*allocator)->Alloc(allocator, size_to_allocate);
|
||||
void* p_data = allocator->Alloc(allocator, size_to_allocate);
|
||||
if (p_data == nullptr)
|
||||
return OrtCreateStatus(ORT_FAIL, "size overflow");
|
||||
*out = std::make_unique<Tensor>(DataTypeImpl::GetType<T>(),
|
||||
onnxruntime::TensorShape(shapes),
|
||||
static_cast<void*>(p_data),
|
||||
*(*allocator)->Info(allocator),
|
||||
*allocator->Info(allocator),
|
||||
std::make_shared<onnxruntime::AllocatorWrapper>(allocator));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -574,7 +572,7 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, si
|
|||
}
|
||||
|
||||
static char* StrDup(const std::string& str, OrtAllocator* allocator) {
|
||||
char* output_string = reinterpret_cast<char*>((*allocator)->Alloc(allocator, str.size() + 1));
|
||||
char* output_string = reinterpret_cast<char*>(allocator->Alloc(allocator, str.size() + 1));
|
||||
memcpy(output_string, str.c_str(), str.size());
|
||||
output_string[str.size()] = '\0';
|
||||
return output_string;
|
||||
|
|
@ -603,7 +601,7 @@ ORT_API(int, OrtIsTensor, _In_ const OrtValue* value) {
|
|||
|
||||
ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) {
|
||||
try {
|
||||
return (*ptr)->Alloc(ptr, size);
|
||||
return ptr->Alloc(ptr, size);
|
||||
} catch (std::exception&) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -611,14 +609,14 @@ ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) {
|
|||
|
||||
ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p) {
|
||||
try {
|
||||
(*ptr)->Free(ptr, p);
|
||||
ptr->Free(ptr, p);
|
||||
} catch (std::exception&) {
|
||||
}
|
||||
}
|
||||
|
||||
ORT_API(const struct OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr) {
|
||||
try {
|
||||
return (*ptr)->Info(ptr);
|
||||
return ptr->Info(ptr);
|
||||
} catch (std::exception&) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -638,10 +636,7 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char)
|
||||
|
||||
ORT_API(void, OrtReleaseEnv, OrtEnv* env) {
|
||||
OrtReleaseObject(env);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -545,7 +545,7 @@ Status OnnxTestCase::ConvertTestData(OrtSession* session, const std::vector<onnx
|
|||
ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, allocator, &temp_name));
|
||||
}
|
||||
var_names[i] = temp_name;
|
||||
(*allocator)->Free(allocator, temp_name);
|
||||
allocator->Free(allocator, temp_name);
|
||||
}
|
||||
}
|
||||
for (size_t input_index = 0; input_index != test_data_pbs.size(); ++input_index) {
|
||||
|
|
|
|||
|
|
@ -287,13 +287,17 @@ SeqTestRunner::SeqTestRunner(OrtSession* session1,
|
|||
TestCaseCallBack on_finished1) : DataRunner(session1, c->GetTestCaseName(), c, on_finished1), repeat_count_(repeat_count) {
|
||||
}
|
||||
|
||||
DataRunner::DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1) : test_case_name_(test_case_name1), c_(c), session(session1), on_finished(on_finished1), default_allocator(MockedOrtAllocator::Create()) {
|
||||
DataRunner::DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1) : test_case_name_(test_case_name1), c_(c), session(session1), on_finished(on_finished1), default_allocator(std::make_unique<MockedOrtAllocator>()) {
|
||||
std::string s;
|
||||
c->GetNodeName(&s);
|
||||
result = std::make_shared<TestCaseResult>(c->GetDataCount(), EXECUTE_RESULT::UNKNOWN_ERROR, s);
|
||||
SetTimeSpecToZero(&spent_time_);
|
||||
}
|
||||
|
||||
DataRunner::~DataRunner() {
|
||||
OrtReleaseSession(session);
|
||||
}
|
||||
|
||||
void DataRunner::RunTask(size_t task_id, ORT_CALLBACK_INSTANCE pci, bool store_result) {
|
||||
EXECUTE_RESULT res = EXECUTE_RESULT::UNKNOWN_ERROR;
|
||||
try {
|
||||
|
|
@ -326,10 +330,10 @@ EXECUTE_RESULT DataRunner::RunTaskImpl(size_t task_id) {
|
|||
std::vector<std::string> output_names(output_count);
|
||||
for (size_t i = 0; i != output_count; ++i) {
|
||||
char* output_name = nullptr;
|
||||
ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, default_allocator, &output_name));
|
||||
ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, default_allocator.get(), &output_name));
|
||||
assert(output_name != nullptr);
|
||||
output_names[i] = output_name;
|
||||
(*default_allocator)->Free(default_allocator, output_name);
|
||||
default_allocator->Free(output_name);
|
||||
}
|
||||
|
||||
TIME_SPEC start_time, end_time;
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ void ORT_CALLBACK RunTestCase(ORT_CALLBACK_INSTANCE instance, void* context, ORT
|
|||
void ORT_CALLBACK RunSingleDataItem(ORT_CALLBACK_INSTANCE instance, void* context, ORT_WORK work);
|
||||
::onnxruntime::common::Status OnTestCaseFinished(ORT_CALLBACK_INSTANCE pci, TestCaseTask* task, std::shared_ptr<TestCaseResult> result);
|
||||
|
||||
struct MockedOrtAllocator;
|
||||
|
||||
class DataRunner {
|
||||
protected:
|
||||
typedef TestCaseCallBack CALL_BACK;
|
||||
|
|
@ -43,7 +45,7 @@ class DataRunner {
|
|||
private:
|
||||
OrtSession* session;
|
||||
CALL_BACK on_finished;
|
||||
OrtAllocatorInterface** const default_allocator;
|
||||
std::unique_ptr<MockedOrtAllocator> default_allocator;
|
||||
EXECUTE_RESULT RunTaskImpl(size_t task_id);
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(DataRunner);
|
||||
|
||||
|
|
@ -51,10 +53,7 @@ class DataRunner {
|
|||
DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1);
|
||||
virtual void OnTaskFinished(size_t task_id, EXECUTE_RESULT res, ORT_CALLBACK_INSTANCE pci) noexcept = 0;
|
||||
void RunTask(size_t task_id, ORT_CALLBACK_INSTANCE pci, bool store_result);
|
||||
virtual ~DataRunner() {
|
||||
OrtReleaseSession(session);
|
||||
OrtReleaseObject(default_allocator);
|
||||
}
|
||||
virtual ~DataRunner();
|
||||
|
||||
virtual void Start(ORT_CALLBACK_INSTANCE pci, size_t concurrent_runs) = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,6 @@ TEST_F(CApiTest, DefaultAllocator) {
|
|||
memset(p, 0, 100);
|
||||
OrtAllocatorFree(default_allocator.get(), p);
|
||||
const OrtAllocatorInfo* info1 = OrtAllocatorGetInfo(default_allocator.get());
|
||||
const OrtAllocatorInfo* info2 = (*default_allocator)->Info(default_allocator.get());
|
||||
const OrtAllocatorInfo* info2 = default_allocator->Info(default_allocator.get());
|
||||
ASSERT_EQ(0, OrtCompareAllocatorInfo(info1, info2));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class CApiTestImpl : public ::testing::Test {
|
|||
}
|
||||
|
||||
void TearDown() override {
|
||||
if (env) OrtReleaseObject(env);
|
||||
if (env) OrtReleaseEnv(env);
|
||||
}
|
||||
|
||||
// Objects declared here can be used by all tests in the test case for Foo.
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ void TestInference(OrtEnv* env, T model_uri,
|
|||
sf.AppendCustomOpLibPath("libonnxruntime_custom_op_shared_lib_test.so");
|
||||
}
|
||||
std::unique_ptr<OrtSession, decltype(&OrtReleaseSession)> inference_session(sf.OrtCreateSession(model_uri), OrtReleaseSession);
|
||||
std::unique_ptr<OrtAllocator> default_allocator(MockedOrtAllocator::Create());
|
||||
std::unique_ptr<MockedOrtAllocator> default_allocator(std::make_unique<MockedOrtAllocator>());
|
||||
// Now run
|
||||
RunSession(default_allocator.get(), inference_session.get(), dims_x, values_x, expected_dims_y, expected_values_y);
|
||||
}
|
||||
|
|
@ -156,7 +156,7 @@ TEST_F(CApiTest, create_session_without_session_option) {
|
|||
TEST_F(CApiTest, create_tensor) {
|
||||
const char* s[] = {"abc", "kmp"};
|
||||
size_t expected_len = 2;
|
||||
std::unique_ptr<OrtAllocator> default_allocator(MockedOrtAllocator::Create());
|
||||
std::unique_ptr<MockedOrtAllocator> default_allocator(std::make_unique<MockedOrtAllocator>());
|
||||
{
|
||||
std::unique_ptr<OrtValue, decltype(&OrtReleaseValue)> tensor(
|
||||
OrtCreateTensorAsOrtValue(default_allocator.get(), {expected_len}, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING), OrtReleaseValue);
|
||||
|
|
|
|||
|
|
@ -7,79 +7,20 @@
|
|||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include <assert.h>
|
||||
|
||||
#define ORT_ALLOCATOR_IMPL_BEGIN(CLASS_NAME) \
|
||||
class CLASS_NAME { \
|
||||
private: \
|
||||
const OrtAllocatorInterface* vtable_ = &table_; \
|
||||
std::atomic_int ref_count_; \
|
||||
static void* ORT_API_CALL Alloc_(void* this_ptr, size_t size) { \
|
||||
return ((CLASS_NAME*)this_ptr)->Alloc(size); \
|
||||
} \
|
||||
static void ORT_API_CALL Free_(void* this_ptr, void* p) { \
|
||||
return ((CLASS_NAME*)this_ptr)->Free(p); \
|
||||
} \
|
||||
static const OrtAllocatorInfo* ORT_API_CALL Info_(const void* this_ptr) { \
|
||||
return ((const CLASS_NAME*)this_ptr)->Info(); \
|
||||
} \
|
||||
static uint32_t ORT_API_CALL AddRef_(void* this_) { \
|
||||
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
|
||||
return ++this_ptr->ref_count_; \
|
||||
} \
|
||||
static uint32_t ORT_API_CALL Release_(void* this_) { \
|
||||
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
|
||||
uint32_t ret = --this_ptr->ref_count_; \
|
||||
if (ret == 0) \
|
||||
delete this_ptr; \
|
||||
return 0; \
|
||||
} \
|
||||
static OrtAllocatorInterface table_;
|
||||
struct MockedOrtAllocator : OrtAllocator {
|
||||
MockedOrtAllocator();
|
||||
~MockedOrtAllocator();
|
||||
|
||||
#define ORT_ALLOCATOR_IMPL_END \
|
||||
} \
|
||||
;
|
||||
void* Alloc(size_t size);
|
||||
void Free(void* p);
|
||||
const OrtAllocatorInfo* Info() const;
|
||||
|
||||
ORT_ALLOCATOR_IMPL_BEGIN(MockedOrtAllocator)
|
||||
private:
|
||||
std::atomic<size_t> memory_inuse;
|
||||
OrtAllocatorInfo* cpuAllocatorInfo;
|
||||
MockedOrtAllocator() : ref_count_(1), memory_inuse(0) {
|
||||
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
|
||||
}
|
||||
~MockedOrtAllocator() {
|
||||
assert(ref_count_ == 0);
|
||||
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
|
||||
}
|
||||
void LeakCheck();
|
||||
|
||||
public:
|
||||
MockedOrtAllocator(const MockedOrtAllocator&) = delete;
|
||||
MockedOrtAllocator& operator=(const MockedOrtAllocator&) = delete;
|
||||
OrtAllocatorInterface** Upcast() {
|
||||
return const_cast<OrtAllocatorInterface**>(&vtable_);
|
||||
}
|
||||
static OrtAllocatorInterface** Create() {
|
||||
return (OrtAllocatorInterface**)new MockedOrtAllocator();
|
||||
}
|
||||
void* Alloc(size_t size) {
|
||||
constexpr size_t extra_len = sizeof(size_t);
|
||||
memory_inuse.fetch_add(size += extra_len);
|
||||
void* p = ::malloc(size);
|
||||
*(size_t*)p = size;
|
||||
return (char*)p + extra_len;
|
||||
}
|
||||
void Free(void* p) {
|
||||
constexpr size_t extra_len = sizeof(size_t);
|
||||
if (!p) return;
|
||||
p = (char*)p - extra_len;
|
||||
size_t len = *(size_t*)p;
|
||||
memory_inuse.fetch_sub(len);
|
||||
return ::free(p);
|
||||
}
|
||||
const OrtAllocatorInfo* Info() const {
|
||||
return cpuAllocatorInfo;
|
||||
}
|
||||
private:
|
||||
MockedOrtAllocator(const MockedOrtAllocator&) = delete;
|
||||
MockedOrtAllocator& operator=(const MockedOrtAllocator&) = delete;
|
||||
|
||||
void LeakCheck() {
|
||||
if (memory_inuse.load())
|
||||
throw std::runtime_error("memory leak!!!");
|
||||
}
|
||||
ORT_ALLOCATOR_IMPL_END
|
||||
std::atomic<size_t> memory_inuse{0};
|
||||
OrtAllocatorInfo* cpuAllocatorInfo;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,5 +3,39 @@
|
|||
|
||||
#include "test_allocator.h"
|
||||
|
||||
OrtAllocatorInterface MockedOrtAllocator::table_ = {
|
||||
{MockedOrtAllocator::AddRef_, MockedOrtAllocator::Release_}, MockedOrtAllocator::Alloc_, MockedOrtAllocator::Free_, MockedOrtAllocator::Info_};
|
||||
MockedOrtAllocator::MockedOrtAllocator() {
|
||||
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<MockedOrtAllocator*>(this_)->Alloc(size); };
|
||||
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<MockedOrtAllocator*>(this_)->Free(p); };
|
||||
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const MockedOrtAllocator*>(this_)->Info(); };
|
||||
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
|
||||
}
|
||||
|
||||
MockedOrtAllocator::~MockedOrtAllocator() {
|
||||
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
|
||||
}
|
||||
|
||||
void* MockedOrtAllocator::Alloc(size_t size) {
|
||||
constexpr size_t extra_len = sizeof(size_t);
|
||||
memory_inuse.fetch_add(size += extra_len);
|
||||
void* p = ::malloc(size);
|
||||
*(size_t*)p = size;
|
||||
return (char*)p + extra_len;
|
||||
}
|
||||
|
||||
void MockedOrtAllocator::Free(void* p) {
|
||||
constexpr size_t extra_len = sizeof(size_t);
|
||||
if (!p) return;
|
||||
p = (char*)p - extra_len;
|
||||
size_t len = *(size_t*)p;
|
||||
memory_inuse.fetch_sub(len);
|
||||
return ::free(p);
|
||||
}
|
||||
|
||||
const OrtAllocatorInfo* MockedOrtAllocator::Info() const {
|
||||
return cpuAllocatorInfo;
|
||||
}
|
||||
|
||||
void MockedOrtAllocator::LeakCheck() {
|
||||
if (memory_inuse.load())
|
||||
throw std::runtime_error("memory leak!!!");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue