mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Custom thread creation and join hooks (#9426)
This commit is contained in:
parent
9f69d8bbae
commit
21eb747a0f
14 changed files with 387 additions and 44 deletions
|
|
@ -527,6 +527,29 @@ typedef struct OrtApiBase OrtApiBase;
|
|||
*/
|
||||
ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION;
|
||||
|
||||
/** \brief Thread work loop function
|
||||
*
|
||||
* Onnxruntime will provide the working loop on custom thread creation
|
||||
* Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn
|
||||
*/
|
||||
typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param);
|
||||
|
||||
typedef const struct OrtCustomHandleType{ char __place_holder; }* OrtCustomThreadHandle;
|
||||
|
||||
/** \brief Ort custom thread creation function
|
||||
*
|
||||
* The function should return a thread handle to be used in onnxruntime thread pools
|
||||
* Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread
|
||||
*/
|
||||
typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param);
|
||||
|
||||
/** \brief Custom thread join function
|
||||
*
|
||||
* Onnxruntime thread pool destructor will call the function to join a custom thread.
|
||||
* Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn
|
||||
*/
|
||||
typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle);
|
||||
|
||||
/** \brief The C API
|
||||
*
|
||||
* All C API functions are defined inside this structure as pointers to functions.
|
||||
|
|
@ -3037,7 +3060,6 @@ struct OrtApi {
|
|||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name OrtKernelContext
|
||||
|
|
@ -3053,6 +3075,66 @@ struct OrtApi {
|
|||
*/
|
||||
ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
|
||||
/// @}
|
||||
|
||||
/// \name SessionOptions
|
||||
/// @{
|
||||
/** \brief Set custom thread creation function
|
||||
*
|
||||
* \param[in] session options
|
||||
* \param[in] custom thread creation function
|
||||
*
|
||||
* * \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
|
||||
|
||||
/** \brief Set creation options for custom thread
|
||||
*
|
||||
* \param[in] session options
|
||||
* \param[in] custom thread creation options (can be nullptr)
|
||||
*
|
||||
* * \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options);
|
||||
|
||||
/** \brief Set custom thread join function
|
||||
*
|
||||
* \param[in] session options
|
||||
* \param[in] custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set
|
||||
*
|
||||
* * \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
|
||||
/// @}
|
||||
|
||||
/// \name OrtThreadingOptions
|
||||
/// @{
|
||||
/** \brief Set custom thread creation function for global thread pools
|
||||
*
|
||||
* \param[inout] tp_options
|
||||
* \param[in] custom thread creation function
|
||||
*
|
||||
* * \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
|
||||
|
||||
/** \brief Set custom thread creation options for global thread pools
|
||||
*
|
||||
* \param[inout] tp_options
|
||||
* \param[in] custom thread creation options (can be nullptr)
|
||||
*
|
||||
* * \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options);
|
||||
|
||||
/** \brief Set custom thread join function for global thread pools
|
||||
*
|
||||
* \param[inout] tp_options
|
||||
* \param[in] custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set
|
||||
*
|
||||
* * \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
|
||||
/// @}
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -351,6 +351,10 @@ struct SessionOptions : Base<OrtSessionOptions> {
|
|||
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
|
||||
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
|
||||
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
||||
|
||||
SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
|
||||
SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
|
||||
SessionOptions& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
|
||||
};
|
||||
|
||||
/** \brief Wrapper around ::OrtModelMetadata
|
||||
|
|
|
|||
|
|
@ -510,6 +510,21 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
|
|||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
||||
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
||||
ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(p_, &provider_options));
|
||||
return *this;
|
||||
|
|
|
|||
|
|
@ -119,6 +119,15 @@ struct SessionOptions {
|
|||
|
||||
// See onnxruntime_c_api.h for detailed documentation.
|
||||
Status AddInitializer(_In_z_ const char* name, _In_ const OrtValue* val) noexcept;
|
||||
|
||||
// custom function callback to create a thread
|
||||
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
|
||||
|
||||
// custom options to pass to custom_create_thread_fn
|
||||
void* custom_thread_creation_options = nullptr;
|
||||
|
||||
// custom function callback to join a thread
|
||||
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||
#include "core/framework/callback.h"
|
||||
#include "core/platform/env_time.h"
|
||||
#include "core/platform/telemetry.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/types.h>
|
||||
|
|
@ -49,6 +50,12 @@ using FileOffsetType = off_t;
|
|||
class EnvThread {
|
||||
public:
|
||||
virtual ~EnvThread() = default;
|
||||
|
||||
protected:
|
||||
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
|
||||
void* custom_thread_creation_options = nullptr;
|
||||
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
|
||||
OrtCustomThreadHandle custom_thread_handle = nullptr;
|
||||
};
|
||||
|
||||
// Parameters that are required to create a set of threads for a thread pool
|
||||
|
|
@ -67,6 +74,10 @@ struct ThreadOptions {
|
|||
|
||||
// Set or unset denormal as zero.
|
||||
bool set_denormal_as_zero = false;
|
||||
|
||||
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
|
||||
void* custom_thread_creation_options = nullptr;
|
||||
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
|
||||
};
|
||||
/// \brief An interface used by the onnxruntime implementation to
|
||||
/// access operating system functionality like the filesystem etc.
|
||||
|
|
|
|||
|
|
@ -143,47 +143,63 @@ class PosixThread : public EnvThread {
|
|||
PosixThread(const ORTCHAR_T* name_prefix, int index,
|
||||
unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
|
||||
const ThreadOptions& thread_options) {
|
||||
pthread_attr_t attr;
|
||||
int s = pthread_attr_init(&attr);
|
||||
if (s != 0) {
|
||||
auto[err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_attr_init failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
}
|
||||
if (thread_options.stack_size > 0) {
|
||||
s = pthread_attr_setstacksize(&attr, thread_options.stack_size);
|
||||
if (s != 0) {
|
||||
auto[err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
custom_create_thread_fn = thread_options.custom_create_thread_fn;
|
||||
custom_thread_creation_options = thread_options.custom_thread_creation_options;
|
||||
custom_join_thread_fn = thread_options.custom_join_thread_fn;
|
||||
|
||||
if (custom_create_thread_fn) {
|
||||
custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options});
|
||||
if (!custom_thread_handle) {
|
||||
ORT_THROW("custom_create_thread_fn returned invalid handle.");
|
||||
}
|
||||
} else {
|
||||
pthread_attr_t attr;
|
||||
int s = pthread_attr_init(&attr);
|
||||
if (s != 0) {
|
||||
auto [err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_attr_init failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
}
|
||||
if (thread_options.stack_size > 0) {
|
||||
s = pthread_attr_setstacksize(&attr, thread_options.stack_size);
|
||||
if (s != 0) {
|
||||
auto [err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
}
|
||||
}
|
||||
s = pthread_create(&hThread, &attr, ThreadMain,
|
||||
new Param{name_prefix, index, start_address, param, thread_options});
|
||||
if (s != 0) {
|
||||
auto [err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
}
|
||||
}
|
||||
s = pthread_create(&hThread, &attr, ThreadMain,
|
||||
new Param{name_prefix, index, start_address, param, thread_options});
|
||||
if (s != 0) {
|
||||
auto[err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
}
|
||||
#if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__)
|
||||
if (!thread_options.affinity.empty()) {
|
||||
cpu_set_t cpuset;
|
||||
CPU_ZERO(&cpuset);
|
||||
CPU_SET(thread_options.affinity[index], &cpuset);
|
||||
s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset);
|
||||
if (s != 0) {
|
||||
auto[err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
if (!thread_options.affinity.empty()) {
|
||||
cpu_set_t cpuset;
|
||||
CPU_ZERO(&cpuset);
|
||||
CPU_SET(thread_options.affinity[index], &cpuset);
|
||||
s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset);
|
||||
if (s != 0) {
|
||||
auto [err_no, err_msg] = GetSystemError();
|
||||
ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
~PosixThread() override {
|
||||
void* res;
|
||||
if (custom_thread_handle) {
|
||||
custom_join_thread_fn(custom_thread_handle);
|
||||
custom_thread_handle = nullptr;
|
||||
} else {
|
||||
void* res;
|
||||
#ifdef NDEBUG
|
||||
pthread_join(hThread, &res);
|
||||
pthread_join(hThread, &res);
|
||||
#else
|
||||
int ret = pthread_join(hThread, &res);
|
||||
assert(ret == 0);
|
||||
int ret = pthread_join(hThread, &res);
|
||||
assert(ret == 0);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -198,6 +214,9 @@ class PosixThread : public EnvThread {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
static void CustomThreadMain(void* param) {
|
||||
ThreadMain(param);
|
||||
}
|
||||
pthread_t hThread;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -59,18 +59,33 @@ class WindowsThread : public EnvThread {
|
|||
public:
|
||||
WindowsThread(const ORTCHAR_T* name_prefix, int index,
|
||||
unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
|
||||
const ThreadOptions& thread_options)
|
||||
: hThread((HANDLE)_beginthreadex(nullptr, thread_options.stack_size, ThreadMain,
|
||||
new Param{name_prefix, index, start_address, param, thread_options}, 0,
|
||||
&threadID)) {
|
||||
const ThreadOptions& thread_options) {
|
||||
custom_create_thread_fn = thread_options.custom_create_thread_fn;
|
||||
custom_thread_creation_options = thread_options.custom_thread_creation_options;
|
||||
custom_join_thread_fn = thread_options.custom_join_thread_fn;
|
||||
|
||||
if (custom_create_thread_fn) {
|
||||
custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, (OrtThreadWorkerFn)CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options});
|
||||
if (!custom_thread_handle) {
|
||||
ORT_THROW("custom_create_thread_fn returned invalid handle.");
|
||||
}
|
||||
} else {
|
||||
hThread.reset(reinterpret_cast<HANDLE>(_beginthreadex(nullptr, thread_options.stack_size, ThreadMain,
|
||||
new Param{name_prefix, index, start_address, param, thread_options}, 0,
|
||||
&threadID)));
|
||||
}
|
||||
}
|
||||
|
||||
~WindowsThread() {
|
||||
DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE);
|
||||
FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED);
|
||||
if (custom_thread_handle) {
|
||||
custom_join_thread_fn(custom_thread_handle);
|
||||
custom_thread_handle = nullptr;
|
||||
} else {
|
||||
DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE);
|
||||
FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription);
|
||||
|
||||
|
|
@ -100,7 +115,6 @@ class WindowsThread : public EnvThread {
|
|||
// Ignore the error
|
||||
(void)pSetThrDesc(GetCurrentThread(), oss.str().c_str());
|
||||
}
|
||||
|
||||
unsigned ret = 0;
|
||||
ORT_TRY {
|
||||
ret = p->start_address(p->index, p->param);
|
||||
|
|
@ -113,6 +127,15 @@ class WindowsThread : public EnvThread {
|
|||
}
|
||||
#pragma warning(pop)
|
||||
|
||||
static void __stdcall CustomThreadMain(void* param) {
|
||||
std::unique_ptr<Param> p((Param*)param);
|
||||
ORT_TRY {
|
||||
p->start_address(p->index, p->param);
|
||||
}
|
||||
ORT_CATCH(const std::exception&) {
|
||||
p->param->Cancel();
|
||||
}
|
||||
}
|
||||
unsigned threadID = 0;
|
||||
wil::unique_handle hThread;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -296,6 +296,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
|
|||
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL &&
|
||||
to.affinity_vec_len == 0;
|
||||
to.allow_spinning = allow_intra_op_spinning;
|
||||
|
||||
// Set custom threading functions
|
||||
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
|
||||
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
|
||||
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;
|
||||
if (to.custom_create_thread_fn) {
|
||||
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for intra op thread pool");
|
||||
}
|
||||
thread_pool_ =
|
||||
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
|
||||
}
|
||||
|
|
@ -316,6 +324,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
|
|||
to.name = inter_thread_pool_name_.c_str();
|
||||
to.set_denormal_as_zero = set_denormal_as_zero;
|
||||
to.allow_spinning = allow_inter_op_spinning;
|
||||
|
||||
// Set custom threading functions
|
||||
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
|
||||
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
|
||||
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;
|
||||
if (to.custom_create_thread_fn) {
|
||||
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for inter op thread pool");
|
||||
}
|
||||
inter_op_thread_pool_ =
|
||||
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP);
|
||||
if (inter_op_thread_pool_ == nullptr) {
|
||||
|
|
|
|||
|
|
@ -2082,6 +2082,27 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
||||
API_IMPL_BEGIN
|
||||
options->value.custom_create_thread_fn = ort_custom_create_thread_fn;
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options) {
|
||||
API_IMPL_BEGIN
|
||||
options->value.custom_thread_creation_options = ort_custom_thread_creation_options;
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
||||
API_IMPL_BEGIN
|
||||
options->value.custom_join_thread_fn = ort_custom_join_thread_fn;
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
static constexpr OrtApiBase ort_api_base = {
|
||||
&OrtApis::GetApi,
|
||||
&OrtApis::GetVersionString,
|
||||
|
|
@ -2358,6 +2379,12 @@ static constexpr OrtApi ort_api_1_to_10 = {
|
|||
// Version 10 - In development, feel free to add/remove/rearrange here
|
||||
&OrtApis::HasValue,
|
||||
&OrtApis::KernelContext_GetGPUComputeStream,
|
||||
&OrtApis::SessionOptionsSetCustomCreateThreadFn,
|
||||
&OrtApis::SessionOptionsSetCustomThreadCreationOptions,
|
||||
&OrtApis::SessionOptionsSetCustomJoinThreadFn,
|
||||
&OrtApis::SetGlobalCustomCreateThreadFn,
|
||||
&OrtApis::SetGlobalCustomThreadCreationOptions,
|
||||
&OrtApis::SetGlobalCustomJoinThreadFn,
|
||||
};
|
||||
|
||||
// Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other)
|
||||
|
|
|
|||
|
|
@ -317,5 +317,11 @@ ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outp
|
|||
ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out);
|
||||
ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices);
|
||||
ORT_API_STATUS_IMPL(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
|
||||
ORT_API_STATUS_IMPL(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
|
||||
ORT_API_STATUS_IMPL(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options);
|
||||
ORT_API_STATUS_IMPL(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
|
||||
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -28,6 +28,14 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) {
|
|||
}
|
||||
to.set_denormal_as_zero = options.set_denormal_as_zero;
|
||||
|
||||
// set custom thread management members
|
||||
to.custom_create_thread_fn = options.custom_create_thread_fn;
|
||||
to.custom_thread_creation_options = options.custom_thread_creation_options;
|
||||
to.custom_join_thread_fn = options.custom_join_thread_fn;
|
||||
if (to.custom_create_thread_fn) {
|
||||
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set");
|
||||
}
|
||||
|
||||
return std::make_unique<ThreadPool>(env, to, options.name, options.thread_pool_size,
|
||||
options.allow_spinning);
|
||||
}
|
||||
|
|
@ -99,4 +107,31 @@ ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_opt
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
||||
if (!tp_options) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
|
||||
}
|
||||
tp_options->inter_op_thread_pool_params.custom_create_thread_fn = ort_custom_create_thread_fn;
|
||||
tp_options->intra_op_thread_pool_params.custom_create_thread_fn = ort_custom_create_thread_fn;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options) {
|
||||
if (!tp_options) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
|
||||
}
|
||||
tp_options->inter_op_thread_pool_params.custom_thread_creation_options = ort_custom_thread_creation_options;
|
||||
tp_options->intra_op_thread_pool_params.custom_thread_creation_options = ort_custom_thread_creation_options;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
||||
if (!tp_options) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
|
||||
}
|
||||
tp_options->inter_op_thread_pool_params.custom_join_thread_fn = ort_custom_join_thread_fn;
|
||||
tp_options->intra_op_thread_pool_params.custom_join_thread_fn = ort_custom_join_thread_fn;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@ struct OrtThreadPoolParams {
|
|||
|
||||
// Set or unset denormal as zero
|
||||
bool set_denormal_as_zero = false;
|
||||
|
||||
// members to manage custom threads
|
||||
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
|
||||
void* custom_thread_creation_options = nullptr;
|
||||
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
|
||||
};
|
||||
|
||||
struct OrtThreadingOptions {
|
||||
|
|
|
|||
|
|
@ -26,13 +26,42 @@ std::unique_ptr<Ort::Env> ort_env;
|
|||
return -1; \
|
||||
}
|
||||
|
||||
namespace TestGlobalCustomThreadHooks {
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
int32_t custom_thread_creation_options = 5;
|
||||
int32_t custom_creation_hook_called = 0;
|
||||
int32_t custom_join_hook_called = 0;
|
||||
|
||||
OrtCustomThreadHandle CreateThreadCustomized(void* options, OrtThreadWorkerFn work_loop, void* param) {
|
||||
if (*((int32_t*)options) == 5) {
|
||||
custom_creation_hook_called += 1;
|
||||
}
|
||||
threads.push_back(std::thread(work_loop, param));
|
||||
return reinterpret_cast<OrtCustomThreadHandle>(threads.back().native_handle());
|
||||
}
|
||||
|
||||
void JoinThreadCustomized(OrtCustomThreadHandle handle) {
|
||||
for (auto& t : threads) {
|
||||
if (reinterpret_cast<OrtCustomThreadHandle>(t.native_handle()) == handle) {
|
||||
custom_join_hook_called += 1;
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace TestGlobalCustomThreadHooks
|
||||
|
||||
using namespace TestGlobalCustomThreadHooks;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
int status = 0;
|
||||
const int thread_pool_size = std::thread::hardware_concurrency();
|
||||
ORT_TRY {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
OrtThreadingOptions* tp_options;
|
||||
std::unique_ptr<OrtStatus, decltype(OrtApi::ReleaseStatus)> st_ptr(nullptr, g_ort->ReleaseStatus);
|
||||
OrtThreadingOptions* tp_options;
|
||||
|
||||
st_ptr.reset(g_ort->CreateThreadingOptions(&tp_options));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
|
@ -40,10 +69,19 @@ int main(int argc, char** argv) {
|
|||
st_ptr.reset(g_ort->SetGlobalSpinControl(tp_options, 0));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
||||
st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, std::thread::hardware_concurrency()));
|
||||
st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, thread_pool_size));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
||||
st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, std::thread::hardware_concurrency()));
|
||||
st_ptr.reset(g_ort->SetGlobalCustomCreateThreadFn(tp_options, CreateThreadCustomized));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
||||
st_ptr.reset(g_ort->SetGlobalCustomThreadCreationOptions(tp_options, &custom_thread_creation_options));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
||||
st_ptr.reset(g_ort->SetGlobalCustomJoinThreadFn(tp_options, JoinThreadCustomized));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
||||
st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, thread_pool_size));
|
||||
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
|
||||
|
||||
st_ptr.reset(g_ort->SetGlobalDenormalAsZero(tp_options));
|
||||
|
|
@ -63,6 +101,12 @@ int main(int argc, char** argv) {
|
|||
//TODO: Fix the C API issue
|
||||
ort_env.reset(); //If we don't do this, it will crash
|
||||
|
||||
#ifndef _OPENMP
|
||||
const int expexted_custom_calls = (thread_pool_size - 1) << 1;
|
||||
ORT_ENFORCE(custom_creation_hook_called == expexted_custom_calls, "custom thread creation function was not called as expected");
|
||||
ORT_ENFORCE(custom_join_hook_called == expexted_custom_calls, "custom thread join function was not called as expected");
|
||||
#endif
|
||||
|
||||
#ifndef USE_ONNXRUNTIME_DLL
|
||||
//make memory leak checker happy
|
||||
::google::protobuf::ShutdownProtobufLibrary();
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
|
|
@ -18,6 +19,7 @@
|
|||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/onnxruntime_run_options_config_keys.h"
|
||||
#include "core/util/thread_utils.h"
|
||||
#include "providers.h"
|
||||
#include "test_allocator.h"
|
||||
#include "test_fixture.h"
|
||||
|
|
@ -1836,3 +1838,48 @@ TEST(CApiTest, TestConfigureTensorRTProviderOptions) {
|
|||
ASSERT_TRUE(stat(engine_cache_path, &buffer) == 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef _OPENMP
|
||||
namespace TestPerSessionCustomThreadHooks {
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
int32_t custom_thread_creation_options = 5;
|
||||
int32_t custom_creation_hook_called = 0;
|
||||
int32_t custom_join_hook_called = 0;
|
||||
|
||||
OrtCustomThreadHandle CreateThreadCustomized(void* options, OrtThreadWorkerFn work_loop, void* param) {
|
||||
if (*((int32_t*)options) == 5) {
|
||||
custom_creation_hook_called += 1;
|
||||
}
|
||||
threads.push_back(std::thread(work_loop, param));
|
||||
return reinterpret_cast<OrtCustomThreadHandle>(threads.back().native_handle());
|
||||
}
|
||||
|
||||
void JoinThreadCustomized(OrtCustomThreadHandle handle) {
|
||||
for (auto& t : threads) {
|
||||
if (reinterpret_cast<OrtCustomThreadHandle>(t.native_handle()) == handle) {
|
||||
custom_join_hook_called += 1;
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CApiTest, TestPerSessionCustomThreadPoolHooks) {
|
||||
const int32_t thread_count = 3;
|
||||
Ort::SessionOptions session_options;
|
||||
// test both intra and inter op thread pool
|
||||
session_options.SetExecutionMode(ExecutionMode::ORT_PARALLEL);
|
||||
session_options.SetIntraOpNumThreads(thread_count);
|
||||
session_options.SetInterOpNumThreads(thread_count);
|
||||
session_options.SetCustomCreateThreadFn(CreateThreadCustomized);
|
||||
session_options.SetCustomThreadCreationOptions(&custom_thread_creation_options);
|
||||
session_options.SetCustomJoinThreadFn(JoinThreadCustomized);
|
||||
{
|
||||
Ort::Session session(*ort_env, MODEL_URI, session_options);
|
||||
}
|
||||
ASSERT_TRUE(custom_creation_hook_called == (thread_count - 1) << 1);
|
||||
ASSERT_TRUE(custom_join_hook_called == (thread_count - 1) << 1);
|
||||
}
|
||||
|
||||
} // namespace TestPerSessionCustomThreadHooks
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue