diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index ccd7580e4f..07fdca7cd0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -132,6 +132,7 @@ typedef enum OrtErrorCode { ORT_API(void, OrtRelease##X, _Frees_ptr_opt_ Ort##X* input); // The actual types defined have an Ort prefix +ORT_RUNTIME_CLASS(Env); ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success ORT_RUNTIME_CLASS(Provider); ORT_RUNTIME_CLASS(AllocatorInfo); @@ -147,8 +148,6 @@ struct OrtRunOptions; typedef struct OrtRunOptions OrtRunOptions; struct OrtSessionOptions; typedef struct OrtSessionOptions OrtSessionOptions; -struct OrtEnv; -typedef struct OrtEnv OrtEnv; /** * Every type inherented from OrtObject should be deleted by OrtReleaseObject(...). @@ -161,17 +160,15 @@ typedef struct OrtObject { } OrtObject; -//inherented from OrtObject -typedef struct OrtAllocatorInterface { - struct OrtObject parent; - void*(ORT_API_CALL* Alloc)(void* this_, size_t size); - void(ORT_API_CALL* Free)(void* this_, void* p); - const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const void* this_); -} OrtAllocatorInterface; +// When passing in an allocator to any ORT function, be sure that the allocator object +// is not destroyed until the last allocated object using it is freed. +typedef struct OrtAllocator { + void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); + void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); + const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); +} OrtAllocator; -typedef OrtAllocatorInterface* OrtAllocator; - -//Inherented from OrtObject +// Inherented from OrtObject typedef struct OrtProviderFactoryInterface { OrtObject parent; OrtStatus*(ORT_API_CALL* CreateProvider)(void* this_, OrtProvider** out); @@ -183,7 +180,7 @@ typedef void(ORT_API_CALL* OrtLoggingFunction)( /** * OrtEnv is process-wide. For each process, only one OrtEnv can be created. - * \param out Should be freed by `OrtReleaseObject` after use + * \param out Should be freed by `OrtReleaseEnv` after use */ ORT_API_STATUS(OrtInitialize, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Out_ OrtEnv** out) ORT_ALL_ARGS_NONNULL; @@ -216,7 +213,7 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, */ ORT_API(OrtSessionOptions*, OrtCreateSessionOptions); -/// create a copy of an existing OrtSessionOptions +// create a copy of an existing OrtSessionOptions ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions*); ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options); ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options); @@ -238,13 +235,13 @@ ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options); ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options); ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options); -///< logger id to use for session output +// < logger id to use for session output ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid); -///< applies to session load, initialization, etc +// < applies to session load, initialization, etc ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level); -///How many threads in the session thread pool. +// How many threads in the session thread pool. ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size); /** @@ -269,6 +266,9 @@ ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t i */ ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out); +/** + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it. + */ ORT_API_STATUS(OrtSessionGetInputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Out_ char** value); ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t index, @@ -291,8 +291,7 @@ ORT_API(void, OrtRunOptionsSetTerminate, _In_ OrtRunOptions*, _In_ int flag); /** * Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value - * \param out will keep a reference to the allocator, without reference counting(will be fixed). Should be freed by - * calling OrtReleaseValue + * \param out Should be freed by calling OrtReleaseValue * \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx */ ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, @@ -441,8 +440,8 @@ ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size); ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p); ORT_API(const OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr); -// Call OrtReleaseObject to release the returned value ORT_API_STATUS(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out); +ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator); /** * \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus @@ -454,16 +453,11 @@ ORT_API(OrtErrorCode, OrtGetErrorCode, _In_ const OrtStatus* status) ORT_ALL_ARGS_NONNULL; /** * \param status must not be NULL - * \return The error message inside the `status`. Don't free the returned value. + * \return The error message inside the `status`. Do not free the returned value. */ ORT_API(const char*, OrtGetErrorMessage, _In_ const OrtStatus* status) ORT_ALL_ARGS_NONNULL; -/** - * Deprecated. Please use OrtReleaseObject - */ -ORT_API(void, OrtReleaseEnv, OrtEnv* env); - #ifdef __cplusplus } #endif diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index afdc9fa9bd..5f7bb2cb5c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -24,6 +24,22 @@ return Ort##NAME(value.get()); \ } +namespace std { +template <> +struct default_delete { + void operator()(OrtAllocator* ptr) { + OrtReleaseAllocator(ptr); + } +}; + +template <> +struct default_delete { + void operator()(OrtEnv* ptr) { + OrtReleaseEnv(ptr); + } +}; +} // namespace std + #define DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TYPE_NAME) \ namespace std { \ template <> \ @@ -34,9 +50,7 @@ }; \ } -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(Env); DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TypeInfo); -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(Allocator); DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TensorTypeAndShapeInfo); DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(RunOptions); DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(SessionOptions); diff --git a/onnxruntime/core/providers/cpu/symbols.txt b/onnxruntime/core/providers/cpu/symbols.txt index 2c063e02f1..1940e8d441 100644 --- a/onnxruntime/core/providers/cpu/symbols.txt +++ b/onnxruntime/core/providers/cpu/symbols.txt @@ -44,6 +44,7 @@ OrtGetValueType OrtInitialize OrtInitializeWithCustomLogger OrtIsTensor +OrtReleaseAllocator OrtReleaseAllocatorInfo OrtReleaseEnv OrtReleaseObject diff --git a/onnxruntime/core/session/allocator_impl.h b/onnxruntime/core/session/allocator_impl.h index cd28f5b416..f09670fac7 100644 --- a/onnxruntime/core/session/allocator_impl.h +++ b/onnxruntime/core/session/allocator_impl.h @@ -8,20 +8,15 @@ namespace onnxruntime { class AllocatorWrapper : public IAllocator { public: - AllocatorWrapper(OrtAllocator* impl) : impl_(impl) { - (*impl)->parent.AddRef(impl); - } - ~AllocatorWrapper() { - (*impl_)->parent.Release(impl_); - } + AllocatorWrapper(OrtAllocator* impl) : impl_(impl) {} void* Alloc(size_t size) override { - return (*impl_)->Alloc(impl_, size); + return impl_->Alloc(impl_, size); } void Free(void* p) override { - return (*impl_)->Free(impl_, p); + return impl_->Free(impl_, p); } const OrtAllocatorInfo& Info() const override { - return *(OrtAllocatorInfo*)(*impl_)->Info(impl_); + return *impl_->Info(impl_); } private: diff --git a/onnxruntime/core/session/default_cpu_allocator_c_api.cc b/onnxruntime/core/session/default_cpu_allocator_c_api.cc index 8957fbdf4f..704b1e1d0c 100644 --- a/onnxruntime/core/session/default_cpu_allocator_c_api.cc +++ b/onnxruntime/core/session/default_cpu_allocator_c_api.cc @@ -5,81 +5,55 @@ #include "core/session/onnxruntime_cxx_api.h" #include -#define ORT_ALLOCATOR_IMPL_BEGIN(CLASS_NAME) \ - class CLASS_NAME { \ - private: \ - const OrtAllocatorInterface* vtable_ = &table_; \ - std::atomic_int ref_count_; \ - static void* ORT_API_CALL Alloc_(void* this_ptr, size_t size) { \ - return ((CLASS_NAME*)this_ptr)->Alloc(size); \ - } \ - static void ORT_API_CALL Free_(void* this_ptr, void* p) { \ - return ((CLASS_NAME*)this_ptr)->Free(p); \ - } \ - static const OrtAllocatorInfo* ORT_API_CALL Info_(const void* this_ptr) { \ - return ((const CLASS_NAME*)this_ptr)->Info(); \ - } \ - static uint32_t ORT_API_CALL AddRef_(void* this_) { \ - CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \ - return ++this_ptr->ref_count_; \ - } \ - static uint32_t ORT_API_CALL Release_(void* this_) { \ - CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \ - uint32_t ret = --this_ptr->ref_count_; \ - if (ret == 0) \ - delete this_ptr; \ - return 0; \ - } \ - static OrtAllocatorInterface table_; +// In the future we'll have more than one allocator type. Since all allocators are of type 'OrtAllocator' and there is a single +// OrtReleaseAllocator function, we need to have a common base type that lets us delete them. +struct OrtAllocatorImpl : OrtAllocator { + virtual ~OrtAllocatorImpl() {} +}; -#define ORT_ALLOCATOR_IMPL_END \ - } \ - ; +struct OrtDefaultAllocator : OrtAllocatorImpl { + OrtDefaultAllocator() { + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; + OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo)); + } -ORT_ALLOCATOR_IMPL_BEGIN(OrtDefaultAllocator) -private: -OrtAllocatorInfo* cpuAllocatorInfo; -OrtDefaultAllocator() : ref_count_(1) { - ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo)); -} -~OrtDefaultAllocator() { - assert(ref_count_ == 0); - OrtReleaseAllocatorInfo(cpuAllocatorInfo); -} + ~OrtDefaultAllocator() { + OrtReleaseAllocatorInfo(cpuAllocatorInfo); + } -public: -OrtDefaultAllocator(const OrtDefaultAllocator&) = delete; -OrtDefaultAllocator& operator=(const OrtDefaultAllocator&) = delete; -OrtAllocatorInterface** Upcast() { - return const_cast(&vtable_); -} -static OrtAllocatorInterface** Create() { - return (OrtAllocatorInterface**)new OrtDefaultAllocator(); -} -void* Alloc(size_t size) { - return ::malloc(size); -} -void Free(void* p) { - return ::free(p); -} -const OrtAllocatorInfo* Info() const { - return cpuAllocatorInfo; -} -ORT_ALLOCATOR_IMPL_END + void* Alloc(size_t size) { + return ::malloc(size); + } + void Free(void* p) { + return ::free(p); + } + const OrtAllocatorInfo* Info() const { + return cpuAllocatorInfo; + } + + private: + OrtDefaultAllocator(const OrtDefaultAllocator&) = delete; + OrtDefaultAllocator& operator=(const OrtDefaultAllocator&) = delete; + + OrtAllocatorInfo* cpuAllocatorInfo; +}; #define API_IMPL_BEGIN try { -#define API_IMPL_END \ - } \ - catch (std::exception & ex) { \ +#define API_IMPL_END \ + } \ + catch (std::exception & ex) { \ return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ } -OrtAllocatorInterface OrtDefaultAllocator::table_ = { - {OrtDefaultAllocator::AddRef_, OrtDefaultAllocator::Release_}, OrtDefaultAllocator::Alloc_, OrtDefaultAllocator::Free_, OrtDefaultAllocator::Info_}; - ORT_API_STATUS_IMPL(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out) { API_IMPL_BEGIN - *out = OrtDefaultAllocator::Create(); + *out = new OrtDefaultAllocator(); return nullptr; API_IMPL_END } + +ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator) { + delete static_cast(allocator); +} diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f226e35ff3..05911b00d1 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -45,20 +45,18 @@ using onnxruntime::common::Status; if (_status) return _status; \ } while (0) -struct OrtEnv : public onnxruntime::ObjectBase { +struct OrtEnv { public: Environment* value; LoggingManager* loggingManager; friend class onnxruntime::ObjectBase; OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) { - ORT_CHECK_C_OBJECT_LAYOUT; } /** * This function will call ::google::protobuf::ShutdownProtobufLibrary */ ~OrtEnv() { - assert(ref_count == 0); delete loggingManager; delete value; } @@ -166,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtFillStringTensor, _In_ OrtValue* value, _In_ const char* } template -OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocatorInterface** allocator, +OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocator* allocator, std::unique_ptr* out) { size_t elem_count = 1; std::vector shapes(shape_len); @@ -179,13 +177,13 @@ OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocatorI if (!IAllocator::CalcMemSizeForArray(sizeof(T), elem_count, &size_to_allocate)) { return OrtCreateStatus(ORT_FAIL, "not enough memory"); } - void* p_data = (*allocator)->Alloc(allocator, size_to_allocate); + void* p_data = allocator->Alloc(allocator, size_to_allocate); if (p_data == nullptr) return OrtCreateStatus(ORT_FAIL, "size overflow"); *out = std::make_unique(DataTypeImpl::GetType(), onnxruntime::TensorShape(shapes), static_cast(p_data), - *(*allocator)->Info(allocator), + *allocator->Info(allocator), std::make_shared(allocator)); return nullptr; } @@ -574,7 +572,7 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, si } static char* StrDup(const std::string& str, OrtAllocator* allocator) { - char* output_string = reinterpret_cast((*allocator)->Alloc(allocator, str.size() + 1)); + char* output_string = reinterpret_cast(allocator->Alloc(allocator, str.size() + 1)); memcpy(output_string, str.c_str(), str.size()); output_string[str.size()] = '\0'; return output_string; @@ -603,7 +601,7 @@ ORT_API(int, OrtIsTensor, _In_ const OrtValue* value) { ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) { try { - return (*ptr)->Alloc(ptr, size); + return ptr->Alloc(ptr, size); } catch (std::exception&) { return nullptr; } @@ -611,14 +609,14 @@ ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) { ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p) { try { - (*ptr)->Free(ptr, p); + ptr->Free(ptr, p); } catch (std::exception&) { } } ORT_API(const struct OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr) { try { - return (*ptr)->Info(ptr); + return ptr->Info(ptr); } catch (std::exception&) { return nullptr; } @@ -638,10 +636,7 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t API_IMPL_END } +DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char) - -ORT_API(void, OrtReleaseEnv, OrtEnv* env) { - OrtReleaseObject(env); -} diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index d77c809d99..47137c30d0 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -545,7 +545,7 @@ Status OnnxTestCase::ConvertTestData(OrtSession* session, const std::vectorFree(allocator, temp_name); + allocator->Free(allocator, temp_name); } } for (size_t input_index = 0; input_index != test_data_pbs.size(); ++input_index) { diff --git a/onnxruntime/test/onnx/runner.cc b/onnxruntime/test/onnx/runner.cc index d21fe79b27..f4f6d8765f 100644 --- a/onnxruntime/test/onnx/runner.cc +++ b/onnxruntime/test/onnx/runner.cc @@ -287,13 +287,17 @@ SeqTestRunner::SeqTestRunner(OrtSession* session1, TestCaseCallBack on_finished1) : DataRunner(session1, c->GetTestCaseName(), c, on_finished1), repeat_count_(repeat_count) { } -DataRunner::DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1) : test_case_name_(test_case_name1), c_(c), session(session1), on_finished(on_finished1), default_allocator(MockedOrtAllocator::Create()) { +DataRunner::DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1) : test_case_name_(test_case_name1), c_(c), session(session1), on_finished(on_finished1), default_allocator(std::make_unique()) { std::string s; c->GetNodeName(&s); result = std::make_shared(c->GetDataCount(), EXECUTE_RESULT::UNKNOWN_ERROR, s); SetTimeSpecToZero(&spent_time_); } +DataRunner::~DataRunner() { + OrtReleaseSession(session); +} + void DataRunner::RunTask(size_t task_id, ORT_CALLBACK_INSTANCE pci, bool store_result) { EXECUTE_RESULT res = EXECUTE_RESULT::UNKNOWN_ERROR; try { @@ -326,10 +330,10 @@ EXECUTE_RESULT DataRunner::RunTaskImpl(size_t task_id) { std::vector output_names(output_count); for (size_t i = 0; i != output_count; ++i) { char* output_name = nullptr; - ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, default_allocator, &output_name)); + ORT_THROW_ON_ERROR(OrtSessionGetOutputName(session, i, default_allocator.get(), &output_name)); assert(output_name != nullptr); output_names[i] = output_name; - (*default_allocator)->Free(default_allocator, output_name); + default_allocator->Free(output_name); } TIME_SPEC start_time, end_time; diff --git a/onnxruntime/test/onnx/runner.h b/onnxruntime/test/onnx/runner.h index f3dfd852cd..8b35c7a357 100644 --- a/onnxruntime/test/onnx/runner.h +++ b/onnxruntime/test/onnx/runner.h @@ -31,6 +31,8 @@ void ORT_CALLBACK RunTestCase(ORT_CALLBACK_INSTANCE instance, void* context, ORT void ORT_CALLBACK RunSingleDataItem(ORT_CALLBACK_INSTANCE instance, void* context, ORT_WORK work); ::onnxruntime::common::Status OnTestCaseFinished(ORT_CALLBACK_INSTANCE pci, TestCaseTask* task, std::shared_ptr result); +struct MockedOrtAllocator; + class DataRunner { protected: typedef TestCaseCallBack CALL_BACK; @@ -43,7 +45,7 @@ class DataRunner { private: OrtSession* session; CALL_BACK on_finished; - OrtAllocatorInterface** const default_allocator; + std::unique_ptr default_allocator; EXECUTE_RESULT RunTaskImpl(size_t task_id); ORT_DISALLOW_COPY_AND_ASSIGNMENT(DataRunner); @@ -51,10 +53,7 @@ class DataRunner { DataRunner(OrtSession* session1, const std::string& test_case_name1, ITestCase* c, TestCaseCallBack on_finished1); virtual void OnTaskFinished(size_t task_id, EXECUTE_RESULT res, ORT_CALLBACK_INSTANCE pci) noexcept = 0; void RunTask(size_t task_id, ORT_CALLBACK_INSTANCE pci, bool store_result); - virtual ~DataRunner() { - OrtReleaseSession(session); - OrtReleaseObject(default_allocator); - } + virtual ~DataRunner(); virtual void Start(ORT_CALLBACK_INSTANCE pci, size_t concurrent_runs) = 0; diff --git a/onnxruntime/test/shared_lib/test_allocator.cc b/onnxruntime/test/shared_lib/test_allocator.cc index ab97260e95..92d4533bac 100644 --- a/onnxruntime/test/shared_lib/test_allocator.cc +++ b/onnxruntime/test/shared_lib/test_allocator.cc @@ -29,6 +29,6 @@ TEST_F(CApiTest, DefaultAllocator) { memset(p, 0, 100); OrtAllocatorFree(default_allocator.get(), p); const OrtAllocatorInfo* info1 = OrtAllocatorGetInfo(default_allocator.get()); - const OrtAllocatorInfo* info2 = (*default_allocator)->Info(default_allocator.get()); + const OrtAllocatorInfo* info2 = default_allocator->Info(default_allocator.get()); ASSERT_EQ(0, OrtCompareAllocatorInfo(info1, info2)); } diff --git a/onnxruntime/test/shared_lib/test_fixture.h b/onnxruntime/test/shared_lib/test_fixture.h index 20b5cb877e..7c766c2b53 100644 --- a/onnxruntime/test/shared_lib/test_fixture.h +++ b/onnxruntime/test/shared_lib/test_fixture.h @@ -29,7 +29,7 @@ class CApiTestImpl : public ::testing::Test { } void TearDown() override { - if (env) OrtReleaseObject(env); + if (env) OrtReleaseEnv(env); } // Objects declared here can be used by all tests in the test case for Foo. diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 3407263970..e623d11de9 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -99,7 +99,7 @@ void TestInference(OrtEnv* env, T model_uri, sf.AppendCustomOpLibPath("libonnxruntime_custom_op_shared_lib_test.so"); } std::unique_ptr inference_session(sf.OrtCreateSession(model_uri), OrtReleaseSession); - std::unique_ptr default_allocator(MockedOrtAllocator::Create()); + std::unique_ptr default_allocator(std::make_unique()); // Now run RunSession(default_allocator.get(), inference_session.get(), dims_x, values_x, expected_dims_y, expected_values_y); } @@ -156,7 +156,7 @@ TEST_F(CApiTest, create_session_without_session_option) { TEST_F(CApiTest, create_tensor) { const char* s[] = {"abc", "kmp"}; size_t expected_len = 2; - std::unique_ptr default_allocator(MockedOrtAllocator::Create()); + std::unique_ptr default_allocator(std::make_unique()); { std::unique_ptr tensor( OrtCreateTensorAsOrtValue(default_allocator.get(), {expected_len}, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING), OrtReleaseValue); diff --git a/onnxruntime/test/util/include/test_allocator.h b/onnxruntime/test/util/include/test_allocator.h index 660320c901..ed1f44d578 100644 --- a/onnxruntime/test/util/include/test_allocator.h +++ b/onnxruntime/test/util/include/test_allocator.h @@ -7,79 +7,20 @@ #include "core/session/onnxruntime_cxx_api.h" #include -#define ORT_ALLOCATOR_IMPL_BEGIN(CLASS_NAME) \ - class CLASS_NAME { \ - private: \ - const OrtAllocatorInterface* vtable_ = &table_; \ - std::atomic_int ref_count_; \ - static void* ORT_API_CALL Alloc_(void* this_ptr, size_t size) { \ - return ((CLASS_NAME*)this_ptr)->Alloc(size); \ - } \ - static void ORT_API_CALL Free_(void* this_ptr, void* p) { \ - return ((CLASS_NAME*)this_ptr)->Free(p); \ - } \ - static const OrtAllocatorInfo* ORT_API_CALL Info_(const void* this_ptr) { \ - return ((const CLASS_NAME*)this_ptr)->Info(); \ - } \ - static uint32_t ORT_API_CALL AddRef_(void* this_) { \ - CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \ - return ++this_ptr->ref_count_; \ - } \ - static uint32_t ORT_API_CALL Release_(void* this_) { \ - CLASS_NAME* this_ptr = (CLASS_NAME*)this_; \ - uint32_t ret = --this_ptr->ref_count_; \ - if (ret == 0) \ - delete this_ptr; \ - return 0; \ - } \ - static OrtAllocatorInterface table_; +struct MockedOrtAllocator : OrtAllocator { + MockedOrtAllocator(); + ~MockedOrtAllocator(); -#define ORT_ALLOCATOR_IMPL_END \ - } \ - ; + void* Alloc(size_t size); + void Free(void* p); + const OrtAllocatorInfo* Info() const; -ORT_ALLOCATOR_IMPL_BEGIN(MockedOrtAllocator) -private: -std::atomic memory_inuse; -OrtAllocatorInfo* cpuAllocatorInfo; -MockedOrtAllocator() : ref_count_(1), memory_inuse(0) { - ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo)); -} -~MockedOrtAllocator() { - assert(ref_count_ == 0); - OrtReleaseAllocatorInfo(cpuAllocatorInfo); -} + void LeakCheck(); -public: -MockedOrtAllocator(const MockedOrtAllocator&) = delete; -MockedOrtAllocator& operator=(const MockedOrtAllocator&) = delete; -OrtAllocatorInterface** Upcast() { - return const_cast(&vtable_); -} -static OrtAllocatorInterface** Create() { - return (OrtAllocatorInterface**)new MockedOrtAllocator(); -} -void* Alloc(size_t size) { - constexpr size_t extra_len = sizeof(size_t); - memory_inuse.fetch_add(size += extra_len); - void* p = ::malloc(size); - *(size_t*)p = size; - return (char*)p + extra_len; -} -void Free(void* p) { - constexpr size_t extra_len = sizeof(size_t); - if (!p) return; - p = (char*)p - extra_len; - size_t len = *(size_t*)p; - memory_inuse.fetch_sub(len); - return ::free(p); -} -const OrtAllocatorInfo* Info() const { - return cpuAllocatorInfo; -} + private: + MockedOrtAllocator(const MockedOrtAllocator&) = delete; + MockedOrtAllocator& operator=(const MockedOrtAllocator&) = delete; -void LeakCheck() { - if (memory_inuse.load()) - throw std::runtime_error("memory leak!!!"); -} -ORT_ALLOCATOR_IMPL_END + std::atomic memory_inuse{0}; + OrtAllocatorInfo* cpuAllocatorInfo; +}; diff --git a/onnxruntime/test/util/test_allocator.cc b/onnxruntime/test/util/test_allocator.cc index 3119042e6d..52ae1bcca2 100644 --- a/onnxruntime/test/util/test_allocator.cc +++ b/onnxruntime/test/util/test_allocator.cc @@ -3,5 +3,39 @@ #include "test_allocator.h" -OrtAllocatorInterface MockedOrtAllocator::table_ = { - {MockedOrtAllocator::AddRef_, MockedOrtAllocator::Release_}, MockedOrtAllocator::Alloc_, MockedOrtAllocator::Free_, MockedOrtAllocator::Info_}; +MockedOrtAllocator::MockedOrtAllocator() { + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; + OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo)); +} + +MockedOrtAllocator::~MockedOrtAllocator() { + OrtReleaseAllocatorInfo(cpuAllocatorInfo); +} + +void* MockedOrtAllocator::Alloc(size_t size) { + constexpr size_t extra_len = sizeof(size_t); + memory_inuse.fetch_add(size += extra_len); + void* p = ::malloc(size); + *(size_t*)p = size; + return (char*)p + extra_len; +} + +void MockedOrtAllocator::Free(void* p) { + constexpr size_t extra_len = sizeof(size_t); + if (!p) return; + p = (char*)p - extra_len; + size_t len = *(size_t*)p; + memory_inuse.fetch_sub(len); + return ::free(p); +} + +const OrtAllocatorInfo* MockedOrtAllocator::Info() const { + return cpuAllocatorInfo; +} + +void MockedOrtAllocator::LeakCheck() { + if (memory_inuse.load()) + throw std::runtime_error("memory leak!!!"); +}