Ryanunderhill/c api 8 (#297)

* Make OrtAllocator not be reference counted

* Make the allocator interface more type safe

* Fix build break

* Build break fix

* Build break fix

* Mistake in previous build fix.

* Fix review comments + build break

* Missed the export symbols

* C specific error, need 'struct' keyword in one case.

* Function calling OrtReleaseObject instead of OrtReleaseEnv
This commit is contained in:
Ryan Hill 2019-01-10 02:06:29 -08:00 committed by GitHub
parent 751eb60819
commit 98a92547bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 154 additions and 203 deletions

View file

@ -132,6 +132,7 @@ typedef enum OrtErrorCode {
ORT_API(void, OrtRelease##X, _Frees_ptr_opt_ Ort##X* input);
// The actual types defined have an Ort prefix
ORT_RUNTIME_CLASS(Env);
ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success
ORT_RUNTIME_CLASS(Provider);
ORT_RUNTIME_CLASS(AllocatorInfo);
@ -147,8 +148,6 @@ struct OrtRunOptions;
typedef struct OrtRunOptions OrtRunOptions;
struct OrtSessionOptions;
typedef struct OrtSessionOptions OrtSessionOptions;
struct OrtEnv;
typedef struct OrtEnv OrtEnv;
/**
* Every type inherented from OrtObject should be deleted by OrtReleaseObject(...).
@ -161,17 +160,15 @@ typedef struct OrtObject {
} OrtObject;
//inherented from OrtObject
typedef struct OrtAllocatorInterface {
struct OrtObject parent;
void*(ORT_API_CALL* Alloc)(void* this_, size_t size);
void(ORT_API_CALL* Free)(void* this_, void* p);
const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const void* this_);
} OrtAllocatorInterface;
// When passing in an allocator to any ORT function, be sure that the allocator object
// is not destroyed until the last allocated object using it is freed.
typedef struct OrtAllocator {
void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size);
void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p);
const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_);
} OrtAllocator;
typedef OrtAllocatorInterface* OrtAllocator;
//Inherented from OrtObject
// Inherented from OrtObject
typedef struct OrtProviderFactoryInterface {
OrtObject parent;
OrtStatus*(ORT_API_CALL* CreateProvider)(void* this_, OrtProvider** out);
@ -183,7 +180,7 @@ typedef void(ORT_API_CALL* OrtLoggingFunction)(
/**
* OrtEnv is process-wide. For each process, only one OrtEnv can be created.
* \param out Should be freed by `OrtReleaseObject` after use
* \param out Should be freed by `OrtReleaseEnv` after use
*/
ORT_API_STATUS(OrtInitialize, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Out_ OrtEnv** out)
ORT_ALL_ARGS_NONNULL;
@ -216,7 +213,7 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
*/
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions);
/// create a copy of an existing OrtSessionOptions
// create a copy of an existing OrtSessionOptions
ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions*);
ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options);
ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options);
@ -238,13 +235,13 @@ ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options);
ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options);
ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options);
///< logger id to use for session output
// < logger id to use for session output
ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid);
///< applies to session load, initialization, etc
// < applies to session load, initialization, etc
ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level);
///How many threads in the session thread pool.
// How many threads in the session thread pool.
ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
/**
@ -269,6 +266,9 @@ ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t i
*/
ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
/**
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it.
*/
ORT_API_STATUS(OrtSessionGetInputName, _In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Out_ char** value);
ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t index,
@ -291,8 +291,7 @@ ORT_API(void, OrtRunOptionsSetTerminate, _In_ OrtRunOptions*, _In_ int flag);
/**
* Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value
* \param out will keep a reference to the allocator, without reference counting(will be fixed). Should be freed by
* calling OrtReleaseValue
* \param out Should be freed by calling OrtReleaseValue
* \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx
*/
ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
@ -441,8 +440,8 @@ ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size);
ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p);
ORT_API(const OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr);
// Call OrtReleaseObject to release the returned value
ORT_API_STATUS(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out);
ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator);
/**
* \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus
@ -454,16 +453,11 @@ ORT_API(OrtErrorCode, OrtGetErrorCode, _In_ const OrtStatus* status)
ORT_ALL_ARGS_NONNULL;
/**
* \param status must not be NULL
* \return The error message inside the `status`. Don't free the returned value.
* \return The error message inside the `status`. Do not free the returned value.
*/
ORT_API(const char*, OrtGetErrorMessage, _In_ const OrtStatus* status)
ORT_ALL_ARGS_NONNULL;
/**
* Deprecated. Please use OrtReleaseObject
*/
ORT_API(void, OrtReleaseEnv, OrtEnv* env);
#ifdef __cplusplus
}
#endif

View file

@ -24,6 +24,22 @@
return Ort##NAME(value.get()); \
}
namespace std {
template <>
struct default_delete<OrtAllocator> {
void operator()(OrtAllocator* ptr) {
OrtReleaseAllocator(ptr);
}
};
template <>
struct default_delete<OrtEnv> {
void operator()(OrtEnv* ptr) {
OrtReleaseEnv(ptr);
}
};
} // namespace std
#define DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TYPE_NAME) \
namespace std { \
template <> \
@ -34,9 +50,7 @@
}; \
}
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(Env);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TypeInfo);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(Allocator);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TensorTypeAndShapeInfo);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(RunOptions);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(SessionOptions);

View file

@ -44,6 +44,7 @@ OrtGetValueType
OrtInitialize
OrtInitializeWithCustomLogger
OrtIsTensor
OrtReleaseAllocator
OrtReleaseAllocatorInfo
OrtReleaseEnv
OrtReleaseObject

View file

@ -8,20 +8,15 @@
namespace onnxruntime {
class AllocatorWrapper : public IAllocator {
public:
AllocatorWrapper(OrtAllocator* impl) : impl_(impl) {
(*impl)->parent.AddRef(impl);
}
~AllocatorWrapper() {
(*impl_)->parent.Release(impl_);
}
AllocatorWrapper(OrtAllocator* impl) : impl_(impl) {}
void* Alloc(size_t size) override {
return (*impl_)->Alloc(impl_, size);
return impl_->Alloc(impl_, size);
}
void Free(void* p) override {
return (*impl_)->Free(impl_, p);
return impl_->Free(impl_, p);
}
const OrtAllocatorInfo& Info() const override {
return *(OrtAllocatorInfo*)(*impl_)->Info(impl_);
return *impl_->Info(impl_);
}
private:

View file

@ -5,81 +5,55 @@
#include "core/session/onnxruntime_cxx_api.h"
#include <assert.h>
#define ORT_ALLOCATOR_IMPL_BEGIN(CLASS_NAME) \
class CLASS_NAME { \
private: \
const OrtAllocatorInterface* vtable_ = &table_; \
std::atomic_int ref_count_; \
static void* ORT_API_CALL Alloc_(void* this_ptr, size_t size) { \
return ((CLASS_NAME*)this_ptr)->Alloc(size); \
} \
static void ORT_API_CALL Free_(void* this_ptr, void* p) { \
return ((CLASS_NAME*)this_ptr)->Free(p); \
} \
static const OrtAllocatorInfo* ORT_API_CALL Info_(const void* this_ptr) { \
return ((const CLASS_NAME*)this_ptr)->Info(); \
} \
static uint32_t ORT_API_CALL AddRef_(void* this_) { \
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
return ++this_ptr->ref_count_; \
} \
static uint32_t ORT_API_CALL Release_(void* this_) { \
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
uint32_t ret = --this_ptr->ref_count_; \
if (ret == 0) \
delete this_ptr; \
return 0; \
} \
static OrtAllocatorInterface table_;
// In the future we'll have more than one allocator type. Since all allocators are of type 'OrtAllocator' and there is a single
// OrtReleaseAllocator function, we need to have a common base type that lets us delete them.
struct OrtAllocatorImpl : OrtAllocator {
virtual ~OrtAllocatorImpl() {}
};
#define ORT_ALLOCATOR_IMPL_END \
} \
;
struct OrtDefaultAllocator : OrtAllocatorImpl {
OrtDefaultAllocator() {
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<OrtDefaultAllocator*>(this_)->Alloc(size); };
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<OrtDefaultAllocator*>(this_)->Free(p); };
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const OrtDefaultAllocator*>(this_)->Info(); };
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
}
ORT_ALLOCATOR_IMPL_BEGIN(OrtDefaultAllocator)
private:
OrtAllocatorInfo* cpuAllocatorInfo;
OrtDefaultAllocator() : ref_count_(1) {
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
}
~OrtDefaultAllocator() {
assert(ref_count_ == 0);
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
}
~OrtDefaultAllocator() {
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
}
public:
OrtDefaultAllocator(const OrtDefaultAllocator&) = delete;
OrtDefaultAllocator& operator=(const OrtDefaultAllocator&) = delete;
OrtAllocatorInterface** Upcast() {
return const_cast<OrtAllocatorInterface**>(&vtable_);
}
static OrtAllocatorInterface** Create() {
return (OrtAllocatorInterface**)new OrtDefaultAllocator();
}
void* Alloc(size_t size) {
return ::malloc(size);
}
void Free(void* p) {
return ::free(p);
}
const OrtAllocatorInfo* Info() const {
return cpuAllocatorInfo;
}
ORT_ALLOCATOR_IMPL_END
void* Alloc(size_t size) {
return ::malloc(size);
}
void Free(void* p) {
return ::free(p);
}
const OrtAllocatorInfo* Info() const {
return cpuAllocatorInfo;
}
private:
OrtDefaultAllocator(const OrtDefaultAllocator&) = delete;
OrtDefaultAllocator& operator=(const OrtDefaultAllocator&) = delete;
OrtAllocatorInfo* cpuAllocatorInfo;
};
#define API_IMPL_BEGIN try {
#define API_IMPL_END \
} \
catch (std::exception & ex) { \
#define API_IMPL_END \
} \
catch (std::exception & ex) { \
return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
}
OrtAllocatorInterface OrtDefaultAllocator::table_ = {
{OrtDefaultAllocator::AddRef_, OrtDefaultAllocator::Release_}, OrtDefaultAllocator::Alloc_, OrtDefaultAllocator::Free_, OrtDefaultAllocator::Info_};
ORT_API_STATUS_IMPL(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out) {
API_IMPL_BEGIN
*out = OrtDefaultAllocator::Create();
*out = new OrtDefaultAllocator();
return nullptr;
API_IMPL_END
}
ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator) {
delete static_cast<OrtAllocatorImpl*>(allocator);
}

View file

@ -45,20 +45,18 @@ using onnxruntime::common::Status;
if (_status) return _status; \
} while (0)
struct OrtEnv : public onnxruntime::ObjectBase<OrtEnv> {
struct OrtEnv {
public:
Environment* value;
LoggingManager* loggingManager;
friend class onnxruntime::ObjectBase<OrtEnv>;
OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) {
ORT_CHECK_C_OBJECT_LAYOUT;
}
/**
* This function will call ::google::protobuf::ShutdownProtobufLibrary
*/
~OrtEnv() {
assert(ref_count == 0);
delete loggingManager;
delete value;
}
@ -166,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtFillStringTensor, _In_ OrtValue* value, _In_ const char*
}
template <typename T>
OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocatorInterface** allocator,
OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocator* allocator,
std::unique_ptr<Tensor>* out) {
size_t elem_count = 1;
std::vector<int64_t> shapes(shape_len);
@ -179,13 +177,13 @@ OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocatorI
if (!IAllocator::CalcMemSizeForArray(sizeof(T), elem_count, &size_to_allocate)) {
return OrtCreateStatus(ORT_FAIL, "not enough memory");
}
void* p_data = (*allocator)->Alloc(allocator, size_to_allocate);
void* p_data = allocator->Alloc(allocator, size_to_allocate);
if (p_data == nullptr)
return OrtCreateStatus(ORT_FAIL, "size overflow");
*out = std::make_unique<Tensor>(DataTypeImpl::GetType<T>(),
onnxruntime::TensorShape(shapes),
static_cast<void*>(p_data),
*(*allocator)->Info(allocator),
*allocator->Info(allocator),
std::make_shared<onnxruntime::AllocatorWrapper>(allocator));
return nullptr;
}
@ -574,7 +572,7 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, si
}
static char* StrDup(const std::string& str, OrtAllocator* allocator) {
char* output_string = reinterpret_cast<char*>((*allocator)->Alloc(allocator, str.size() + 1));
char* output_string = reinterpret_cast<char*>(allocator->Alloc(allocator, str.size() + 1));
memcpy(output_string, str.c_str(), str.size());
output_string[str.size()] = '\0';
return output_string;
@ -603,7 +601,7 @@ ORT_API(int, OrtIsTensor, _In_ const OrtValue* value) {
ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) {
try {
return (*ptr)->Alloc(ptr, size);
return ptr->Alloc(ptr, size);
} catch (std::exception&) {
return nullptr;
}
@ -611,14 +609,14 @@ ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) {
ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p) {
try {
(*ptr)->Free(ptr, p);
ptr->Free(ptr, p);
} catch (std::exception&) {
}
}
ORT_API(const struct OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr) {
try {
return (*ptr)->Info(ptr);
return ptr->Info(ptr);
} catch (std::exception&) {
return nullptr;
}
@ -638,10 +636,7 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t
API_IMPL_END
}
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char)
ORT_API(void, OrtReleaseEnv, OrtEnv* env) {
OrtReleaseObject(env);
}

View file

@ -545,7 +545,7 @@ Status OnnxTestCase::ConvertTestData(OrtSession* session, const std::vector<onnx
ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, allocator, &temp_name));
}
var_names[i] = temp_name;
(*allocator)->Free(allocator, temp_name);
allocator->Free(allocator, temp_name);
}
}
for (size_t input_index = 0; input_index != test_data_pbs.size(); ++input_index) {

View file

@ -287,13 +287,17 @@ SeqTestRunner::SeqTestRunner(OrtSession* session1,
TestCaseCallBack on_finished1) : DataRunner(session1, c->GetTestCaseName(), c, on_finished1), repeat_count_(repeat_count) {
}
DataRunner::DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1) : test_case_name_(test_case_name1), c_(c), session(session1), on_finished(on_finished1), default_allocator(MockedOrtAllocator::Create()) {
DataRunner::DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1) : test_case_name_(test_case_name1), c_(c), session(session1), on_finished(on_finished1), default_allocator(std::make_unique<MockedOrtAllocator>()) {
std::string s;
c->GetNodeName(&s);
result = std::make_shared<TestCaseResult>(c->GetDataCount(), EXECUTE_RESULT::UNKNOWN_ERROR, s);
SetTimeSpecToZero(&spent_time_);
}
DataRunner::~DataRunner() {
OrtReleaseSession(session);
}
void DataRunner::RunTask(size_t task_id, ORT_CALLBACK_INSTANCE pci, bool store_result) {
EXECUTE_RESULT res = EXECUTE_RESULT::UNKNOWN_ERROR;
try {
@ -326,10 +330,10 @@ EXECUTE_RESULT DataRunner::RunTaskImpl(size_t task_id) {
std::vector<std::string> output_names(output_count);
for (size_t i = 0; i != output_count; ++i) {
char* output_name = nullptr;
ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, default_allocator, &output_name));
ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, default_allocator.get(), &output_name));
assert(output_name != nullptr);
output_names[i] = output_name;
(*default_allocator)->Free(default_allocator, output_name);
default_allocator->Free(output_name);
}
TIME_SPEC start_time, end_time;

View file

@ -31,6 +31,8 @@ void ORT_CALLBACK RunTestCase(ORT_CALLBACK_INSTANCE instance, void* context, ORT
void ORT_CALLBACK RunSingleDataItem(ORT_CALLBACK_INSTANCE instance, void* context, ORT_WORK work);
::onnxruntime::common::Status OnTestCaseFinished(ORT_CALLBACK_INSTANCE pci, TestCaseTask* task, std::shared_ptr<TestCaseResult> result);
struct MockedOrtAllocator;
class DataRunner {
protected:
typedef TestCaseCallBack CALL_BACK;
@ -43,7 +45,7 @@ class DataRunner {
private:
OrtSession* session;
CALL_BACK on_finished;
OrtAllocatorInterface** const default_allocator;
std::unique_ptr<MockedOrtAllocator> default_allocator;
EXECUTE_RESULT RunTaskImpl(size_t task_id);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(DataRunner);
@ -51,10 +53,7 @@ class DataRunner {
DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1);
virtual void OnTaskFinished(size_t task_id, EXECUTE_RESULT res, ORT_CALLBACK_INSTANCE pci) noexcept = 0;
void RunTask(size_t task_id, ORT_CALLBACK_INSTANCE pci, bool store_result);
virtual ~DataRunner() {
OrtReleaseSession(session);
OrtReleaseObject(default_allocator);
}
virtual ~DataRunner();
virtual void Start(ORT_CALLBACK_INSTANCE pci, size_t concurrent_runs) = 0;

View file

@ -29,6 +29,6 @@ TEST_F(CApiTest, DefaultAllocator) {
memset(p, 0, 100);
OrtAllocatorFree(default_allocator.get(), p);
const OrtAllocatorInfo* info1 = OrtAllocatorGetInfo(default_allocator.get());
const OrtAllocatorInfo* info2 = (*default_allocator)->Info(default_allocator.get());
const OrtAllocatorInfo* info2 = default_allocator->Info(default_allocator.get());
ASSERT_EQ(0, OrtCompareAllocatorInfo(info1, info2));
}

View file

@ -29,7 +29,7 @@ class CApiTestImpl : public ::testing::Test {
}
void TearDown() override {
if (env) OrtReleaseObject(env);
if (env) OrtReleaseEnv(env);
}
// Objects declared here can be used by all tests in the test case for Foo.

View file

@ -99,7 +99,7 @@ void TestInference(OrtEnv* env, T model_uri,
sf.AppendCustomOpLibPath("libonnxruntime_custom_op_shared_lib_test.so");
}
std::unique_ptr<OrtSession, decltype(&OrtReleaseSession)> inference_session(sf.OrtCreateSession(model_uri), OrtReleaseSession);
std::unique_ptr<OrtAllocator> default_allocator(MockedOrtAllocator::Create());
std::unique_ptr<MockedOrtAllocator> default_allocator(std::make_unique<MockedOrtAllocator>());
// Now run
RunSession(default_allocator.get(), inference_session.get(), dims_x, values_x, expected_dims_y, expected_values_y);
}
@ -156,7 +156,7 @@ TEST_F(CApiTest, create_session_without_session_option) {
TEST_F(CApiTest, create_tensor) {
const char* s[] = {"abc", "kmp"};
size_t expected_len = 2;
std::unique_ptr<OrtAllocator> default_allocator(MockedOrtAllocator::Create());
std::unique_ptr<MockedOrtAllocator> default_allocator(std::make_unique<MockedOrtAllocator>());
{
std::unique_ptr<OrtValue, decltype(&OrtReleaseValue)> tensor(
OrtCreateTensorAsOrtValue(default_allocator.get(), {expected_len}, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING), OrtReleaseValue);

View file

@ -7,79 +7,20 @@
#include "core/session/onnxruntime_cxx_api.h"
#include <assert.h>
#define ORT_ALLOCATOR_IMPL_BEGIN(CLASS_NAME) \
class CLASS_NAME { \
private: \
const OrtAllocatorInterface* vtable_ = &table_; \
std::atomic_int ref_count_; \
static void* ORT_API_CALL Alloc_(void* this_ptr, size_t size) { \
return ((CLASS_NAME*)this_ptr)->Alloc(size); \
} \
static void ORT_API_CALL Free_(void* this_ptr, void* p) { \
return ((CLASS_NAME*)this_ptr)->Free(p); \
} \
static const OrtAllocatorInfo* ORT_API_CALL Info_(const void* this_ptr) { \
return ((const CLASS_NAME*)this_ptr)->Info(); \
} \
static uint32_t ORT_API_CALL AddRef_(void* this_) { \
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
return ++this_ptr->ref_count_; \
} \
static uint32_t ORT_API_CALL Release_(void* this_) { \
CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \
uint32_t ret = --this_ptr->ref_count_; \
if (ret == 0) \
delete this_ptr; \
return 0; \
} \
static OrtAllocatorInterface table_;
struct MockedOrtAllocator : OrtAllocator {
MockedOrtAllocator();
~MockedOrtAllocator();
#define ORT_ALLOCATOR_IMPL_END \
} \
;
void* Alloc(size_t size);
void Free(void* p);
const OrtAllocatorInfo* Info() const;
ORT_ALLOCATOR_IMPL_BEGIN(MockedOrtAllocator)
private:
std::atomic<size_t> memory_inuse;
OrtAllocatorInfo* cpuAllocatorInfo;
MockedOrtAllocator() : ref_count_(1), memory_inuse(0) {
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
}
~MockedOrtAllocator() {
assert(ref_count_ == 0);
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
}
void LeakCheck();
public:
MockedOrtAllocator(const MockedOrtAllocator&) = delete;
MockedOrtAllocator& operator=(const MockedOrtAllocator&) = delete;
OrtAllocatorInterface** Upcast() {
return const_cast<OrtAllocatorInterface**>(&vtable_);
}
static OrtAllocatorInterface** Create() {
return (OrtAllocatorInterface**)new MockedOrtAllocator();
}
void* Alloc(size_t size) {
constexpr size_t extra_len = sizeof(size_t);
memory_inuse.fetch_add(size += extra_len);
void* p = ::malloc(size);
*(size_t*)p = size;
return (char*)p + extra_len;
}
void Free(void* p) {
constexpr size_t extra_len = sizeof(size_t);
if (!p) return;
p = (char*)p - extra_len;
size_t len = *(size_t*)p;
memory_inuse.fetch_sub(len);
return ::free(p);
}
const OrtAllocatorInfo* Info() const {
return cpuAllocatorInfo;
}
private:
MockedOrtAllocator(const MockedOrtAllocator&) = delete;
MockedOrtAllocator& operator=(const MockedOrtAllocator&) = delete;
void LeakCheck() {
if (memory_inuse.load())
throw std::runtime_error("memory leak!!!");
}
ORT_ALLOCATOR_IMPL_END
std::atomic<size_t> memory_inuse{0};
OrtAllocatorInfo* cpuAllocatorInfo;
};

View file

@ -3,5 +3,39 @@
#include "test_allocator.h"
OrtAllocatorInterface MockedOrtAllocator::table_ = {
{MockedOrtAllocator::AddRef_, MockedOrtAllocator::Release_}, MockedOrtAllocator::Alloc_, MockedOrtAllocator::Free_, MockedOrtAllocator::Info_};
MockedOrtAllocator::MockedOrtAllocator() {
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<MockedOrtAllocator*>(this_)->Alloc(size); };
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<MockedOrtAllocator*>(this_)->Free(p); };
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const MockedOrtAllocator*>(this_)->Info(); };
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
}
MockedOrtAllocator::~MockedOrtAllocator() {
OrtReleaseAllocatorInfo(cpuAllocatorInfo);
}
void* MockedOrtAllocator::Alloc(size_t size) {
constexpr size_t extra_len = sizeof(size_t);
memory_inuse.fetch_add(size += extra_len);
void* p = ::malloc(size);
*(size_t*)p = size;
return (char*)p + extra_len;
}
void MockedOrtAllocator::Free(void* p) {
constexpr size_t extra_len = sizeof(size_t);
if (!p) return;
p = (char*)p - extra_len;
size_t len = *(size_t*)p;
memory_inuse.fetch_sub(len);
return ::free(p);
}
const OrtAllocatorInfo* MockedOrtAllocator::Info() const {
return cpuAllocatorInfo;
}
void MockedOrtAllocator::LeakCheck() {
if (memory_inuse.load())
throw std::runtime_error("memory leak!!!");
}