Custom thread creation and join hooks (#9426)

This commit is contained in:
RandySheriffH 2021-11-12 19:10:31 -08:00 committed by GitHub
parent 9f69d8bbae
commit 21eb747a0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 387 additions and 44 deletions

View file

@ -527,6 +527,29 @@ typedef struct OrtApiBase OrtApiBase;
*/
ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION;
/** \brief Thread work loop function
*
* Onnxruntime will provide the working loop on custom thread creation
* Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn
*/
typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param);
typedef const struct OrtCustomHandleType{ char __place_holder; }* OrtCustomThreadHandle;
/** \brief Ort custom thread creation function
*
* The function should return a thread handle to be used in onnxruntime thread pools
* Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread
*/
typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param);
/** \brief Custom thread join function
*
* Onnxruntime thread pool destructor will call the function to join a custom thread.
* Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn
*/
typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle);
/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
@ -3037,7 +3060,6 @@ struct OrtApi {
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out);
/// @}
/// \name OrtKernelContext
@ -3053,6 +3075,66 @@ struct OrtApi {
*/
ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
/// @}
/// \name SessionOptions
/// @{
/** \brief Set custom thread creation function
*
* \param[in] session options
* \param[in] custom thread creation function
*
* * \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
/** \brief Set creation options for custom thread
*
* \param[in] session options
* \param[in] custom thread creation options (can be nullptr)
*
* * \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options);
/** \brief Set custom thread join function
*
* \param[in] session options
* \param[in] custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set
*
* * \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
/// @}
/// \name OrtThreadingOptions
/// @{
/** \brief Set custom thread creation function for global thread pools
*
* \param[inout] tp_options
* \param[in] custom thread creation function
*
* * \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
/** \brief Set custom thread creation options for global thread pools
*
* \param[inout] tp_options
* \param[in] custom thread creation options (can be nullptr)
*
* * \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options);
/** \brief Set custom thread join function for global thread pools
*
* \param[inout] tp_options
* \param[in] custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set
*
* * \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
/// @}
};
/*

View file

@ -351,6 +351,10 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
SessionOptions& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
};
/** \brief Wrapper around ::OrtModelMetadata

View file

@ -510,6 +510,21 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
return *this;
}
inline SessionOptions& SessionOptions::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
return *this;
}
inline SessionOptions& SessionOptions::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
return *this;
}
inline SessionOptions& SessionOptions::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
return *this;
}
inline SessionOptions& SessionOptions::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(p_, &provider_options));
return *this;

View file

@ -119,6 +119,15 @@ struct SessionOptions {
// See onnxruntime_c_api.h for detailed documentation.
Status AddInitializer(_In_z_ const char* name, _In_ const OrtValue* val) noexcept;
// custom function callback to create a thread
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
// custom options to pass to custom_create_thread_fn
void* custom_thread_creation_options = nullptr;
// custom function callback to join a thread
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
};
} // namespace onnxruntime

View file

@ -28,6 +28,7 @@ limitations under the License.
#include "core/framework/callback.h"
#include "core/platform/env_time.h"
#include "core/platform/telemetry.h"
#include "core/session/onnxruntime_c_api.h"
#ifndef _WIN32
#include <sys/types.h>
@ -49,6 +50,12 @@ using FileOffsetType = off_t;
class EnvThread {
public:
virtual ~EnvThread() = default;
protected:
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
void* custom_thread_creation_options = nullptr;
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
OrtCustomThreadHandle custom_thread_handle = nullptr;
};
// Parameters that are required to create a set of threads for a thread pool
@ -67,6 +74,10 @@ struct ThreadOptions {
// Set or unset denormal as zero.
bool set_denormal_as_zero = false;
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
void* custom_thread_creation_options = nullptr;
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
};
/// \brief An interface used by the onnxruntime implementation to
/// access operating system functionality like the filesystem etc.

View file

@ -143,47 +143,63 @@ class PosixThread : public EnvThread {
PosixThread(const ORTCHAR_T* name_prefix, int index,
unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
const ThreadOptions& thread_options) {
pthread_attr_t attr;
int s = pthread_attr_init(&attr);
if (s != 0) {
auto[err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_attr_init failed, error code: ", err_no, " error msg: ", err_msg);
}
if (thread_options.stack_size > 0) {
s = pthread_attr_setstacksize(&attr, thread_options.stack_size);
if (s != 0) {
auto[err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg);
custom_create_thread_fn = thread_options.custom_create_thread_fn;
custom_thread_creation_options = thread_options.custom_thread_creation_options;
custom_join_thread_fn = thread_options.custom_join_thread_fn;
if (custom_create_thread_fn) {
custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options});
if (!custom_thread_handle) {
ORT_THROW("custom_create_thread_fn returned invalid handle.");
}
} else {
pthread_attr_t attr;
int s = pthread_attr_init(&attr);
if (s != 0) {
auto [err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_attr_init failed, error code: ", err_no, " error msg: ", err_msg);
}
if (thread_options.stack_size > 0) {
s = pthread_attr_setstacksize(&attr, thread_options.stack_size);
if (s != 0) {
auto [err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg);
}
}
s = pthread_create(&hThread, &attr, ThreadMain,
new Param{name_prefix, index, start_address, param, thread_options});
if (s != 0) {
auto [err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg);
}
}
s = pthread_create(&hThread, &attr, ThreadMain,
new Param{name_prefix, index, start_address, param, thread_options});
if (s != 0) {
auto[err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg);
}
#if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__)
if (!thread_options.affinity.empty()) {
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(thread_options.affinity[index], &cpuset);
s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset);
if (s != 0) {
auto[err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg);
if (!thread_options.affinity.empty()) {
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(thread_options.affinity[index], &cpuset);
s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset);
if (s != 0) {
auto [err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg);
}
}
}
#endif
}
}
~PosixThread() override {
void* res;
if (custom_thread_handle) {
custom_join_thread_fn(custom_thread_handle);
custom_thread_handle = nullptr;
} else {
void* res;
#ifdef NDEBUG
pthread_join(hThread, &res);
pthread_join(hThread, &res);
#else
int ret = pthread_join(hThread, &res);
assert(ret == 0);
int ret = pthread_join(hThread, &res);
assert(ret == 0);
#endif
}
}
private:
@ -198,6 +214,9 @@ class PosixThread : public EnvThread {
}
return nullptr;
}
static void CustomThreadMain(void* param) {
ThreadMain(param);
}
pthread_t hThread;
};

View file

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -59,18 +59,33 @@ class WindowsThread : public EnvThread {
public:
WindowsThread(const ORTCHAR_T* name_prefix, int index,
unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
const ThreadOptions& thread_options)
: hThread((HANDLE)_beginthreadex(nullptr, thread_options.stack_size, ThreadMain,
new Param{name_prefix, index, start_address, param, thread_options}, 0,
&threadID)) {
const ThreadOptions& thread_options) {
custom_create_thread_fn = thread_options.custom_create_thread_fn;
custom_thread_creation_options = thread_options.custom_thread_creation_options;
custom_join_thread_fn = thread_options.custom_join_thread_fn;
if (custom_create_thread_fn) {
custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, (OrtThreadWorkerFn)CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options});
if (!custom_thread_handle) {
ORT_THROW("custom_create_thread_fn returned invalid handle.");
}
} else {
hThread.reset(reinterpret_cast<HANDLE>(_beginthreadex(nullptr, thread_options.stack_size, ThreadMain,
new Param{name_prefix, index, start_address, param, thread_options}, 0,
&threadID)));
}
}
~WindowsThread() {
DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE);
FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED);
if (custom_thread_handle) {
custom_join_thread_fn(custom_thread_handle);
custom_thread_handle = nullptr;
} else {
DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE);
FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED);
}
}
private:
typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription);
@ -100,7 +115,6 @@ class WindowsThread : public EnvThread {
// Ignore the error
(void)pSetThrDesc(GetCurrentThread(), oss.str().c_str());
}
unsigned ret = 0;
ORT_TRY {
ret = p->start_address(p->index, p->param);
@ -113,6 +127,15 @@ class WindowsThread : public EnvThread {
}
#pragma warning(pop)
static void __stdcall CustomThreadMain(void* param) {
std::unique_ptr<Param> p((Param*)param);
ORT_TRY {
p->start_address(p->index, p->param);
}
ORT_CATCH(const std::exception&) {
p->param->Cancel();
}
}
unsigned threadID = 0;
wil::unique_handle hThread;
};

View file

@ -296,6 +296,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL &&
to.affinity_vec_len == 0;
to.allow_spinning = allow_intra_op_spinning;
// Set custom threading functions
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;
if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for intra op thread pool");
}
thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
}
@ -316,6 +324,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
to.name = inter_thread_pool_name_.c_str();
to.set_denormal_as_zero = set_denormal_as_zero;
to.allow_spinning = allow_inter_op_spinning;
// Set custom threading functions
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;
if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for inter op thread pool");
}
inter_op_thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP);
if (inter_op_thread_pool_ == nullptr) {

View file

@ -2082,6 +2082,27 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
API_IMPL_BEGIN
options->value.custom_create_thread_fn = ort_custom_create_thread_fn;
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options) {
API_IMPL_BEGIN
options->value.custom_thread_creation_options = ort_custom_thread_creation_options;
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
API_IMPL_BEGIN
options->value.custom_join_thread_fn = ort_custom_join_thread_fn;
return nullptr;
API_IMPL_END
}
static constexpr OrtApiBase ort_api_base = {
&OrtApis::GetApi,
&OrtApis::GetVersionString,
@ -2358,6 +2379,12 @@ static constexpr OrtApi ort_api_1_to_10 = {
// Version 10 - In development, feel free to add/remove/rearrange here
&OrtApis::HasValue,
&OrtApis::KernelContext_GetGPUComputeStream,
&OrtApis::SessionOptionsSetCustomCreateThreadFn,
&OrtApis::SessionOptionsSetCustomThreadCreationOptions,
&OrtApis::SessionOptionsSetCustomJoinThreadFn,
&OrtApis::SetGlobalCustomCreateThreadFn,
&OrtApis::SetGlobalCustomThreadCreationOptions,
&OrtApis::SetGlobalCustomJoinThreadFn,
};
// Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other)

View file

@ -317,5 +317,11 @@ ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outp
ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out);
ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices);
ORT_API_STATUS_IMPL(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
ORT_API_STATUS_IMPL(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
ORT_API_STATUS_IMPL(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options);
ORT_API_STATUS_IMPL(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
ORT_API_STATUS_IMPL(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
ORT_API_STATUS_IMPL(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options);
ORT_API_STATUS_IMPL(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
} // namespace OrtApis

View file

@ -28,6 +28,14 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) {
}
to.set_denormal_as_zero = options.set_denormal_as_zero;
// set custom thread management members
to.custom_create_thread_fn = options.custom_create_thread_fn;
to.custom_thread_creation_options = options.custom_thread_creation_options;
to.custom_join_thread_fn = options.custom_join_thread_fn;
if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set");
}
return std::make_unique<ThreadPool>(env, to, options.name, options.thread_pool_size,
options.allow_spinning);
}
@ -99,4 +107,31 @@ ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_opt
return nullptr;
}
ORT_API_STATUS_IMPL(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
if (!tp_options) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
}
tp_options->inter_op_thread_pool_params.custom_create_thread_fn = ort_custom_create_thread_fn;
tp_options->intra_op_thread_pool_params.custom_create_thread_fn = ort_custom_create_thread_fn;
return nullptr;
}
ORT_API_STATUS_IMPL(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options) {
if (!tp_options) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
}
tp_options->inter_op_thread_pool_params.custom_thread_creation_options = ort_custom_thread_creation_options;
tp_options->intra_op_thread_pool_params.custom_thread_creation_options = ort_custom_thread_creation_options;
return nullptr;
}
ORT_API_STATUS_IMPL(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
if (!tp_options) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions");
}
tp_options->inter_op_thread_pool_params.custom_join_thread_fn = ort_custom_join_thread_fn;
tp_options->intra_op_thread_pool_params.custom_join_thread_fn = ort_custom_join_thread_fn;
return nullptr;
}
} // namespace OrtApis

View file

@ -28,6 +28,11 @@ struct OrtThreadPoolParams {
// Set or unset denormal as zero
bool set_denormal_as_zero = false;
// members to manage custom threads
OrtCustomCreateThreadFn custom_create_thread_fn = nullptr;
void* custom_thread_creation_options = nullptr;
OrtCustomJoinThreadFn custom_join_thread_fn = nullptr;
};
struct OrtThreadingOptions {

View file

@ -26,13 +26,42 @@ std::unique_ptr<Ort::Env> ort_env;
return -1; \
}
namespace TestGlobalCustomThreadHooks {
std::vector<std::thread> threads;
int32_t custom_thread_creation_options = 5;
int32_t custom_creation_hook_called = 0;
int32_t custom_join_hook_called = 0;
OrtCustomThreadHandle CreateThreadCustomized(void* options, OrtThreadWorkerFn work_loop, void* param) {
if (*((int32_t*)options) == 5) {
custom_creation_hook_called += 1;
}
threads.push_back(std::thread(work_loop, param));
return reinterpret_cast<OrtCustomThreadHandle>(threads.back().native_handle());
}
void JoinThreadCustomized(OrtCustomThreadHandle handle) {
for (auto& t : threads) {
if (reinterpret_cast<OrtCustomThreadHandle>(t.native_handle()) == handle) {
custom_join_hook_called += 1;
t.join();
}
}
}
} // namespace TestGlobalCustomThreadHooks
using namespace TestGlobalCustomThreadHooks;
int main(int argc, char** argv) {
int status = 0;
const int thread_pool_size = std::thread::hardware_concurrency();
ORT_TRY {
::testing::InitGoogleTest(&argc, argv);
const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
OrtThreadingOptions* tp_options;
std::unique_ptr<OrtStatus, decltype(OrtApi::ReleaseStatus)> st_ptr(nullptr, g_ort->ReleaseStatus);
OrtThreadingOptions* tp_options;
st_ptr.reset(g_ort->CreateThreadingOptions(&tp_options));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
@ -40,10 +69,19 @@ int main(int argc, char** argv) {
st_ptr.reset(g_ort->SetGlobalSpinControl(tp_options, 0));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, std::thread::hardware_concurrency()));
st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, thread_pool_size));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, std::thread::hardware_concurrency()));
st_ptr.reset(g_ort->SetGlobalCustomCreateThreadFn(tp_options, CreateThreadCustomized));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
st_ptr.reset(g_ort->SetGlobalCustomThreadCreationOptions(tp_options, &custom_thread_creation_options));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
st_ptr.reset(g_ort->SetGlobalCustomJoinThreadFn(tp_options, JoinThreadCustomized));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, thread_pool_size));
ORT_RETURN_IF_NON_NULL_STATUS(st_ptr);
st_ptr.reset(g_ort->SetGlobalDenormalAsZero(tp_options));
@ -63,6 +101,12 @@ int main(int argc, char** argv) {
//TODO: Fix the C API issue
ort_env.reset(); //If we don't do this, it will crash
#ifndef _OPENMP
const int expexted_custom_calls = (thread_pool_size - 1) << 1;
ORT_ENFORCE(custom_creation_hook_called == expexted_custom_calls, "custom thread creation function was not called as expected");
ORT_ENFORCE(custom_join_hook_called == expexted_custom_calls, "custom thread join function was not called as expected");
#endif
#ifndef USE_ONNXRUNTIME_DLL
//make memory leak checker happy
::google::protobuf::ShutdownProtobufLibrary();

View file

@ -9,6 +9,7 @@
#include <atomic>
#include <mutex>
#include <algorithm>
#include <thread>
#include <gtest/gtest.h>
@ -18,6 +19,7 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/onnxruntime_run_options_config_keys.h"
#include "core/util/thread_utils.h"
#include "providers.h"
#include "test_allocator.h"
#include "test_fixture.h"
@ -1836,3 +1838,48 @@ TEST(CApiTest, TestConfigureTensorRTProviderOptions) {
ASSERT_TRUE(stat(engine_cache_path, &buffer) == 0);
}
#endif
#ifndef _OPENMP
namespace TestPerSessionCustomThreadHooks {
std::vector<std::thread> threads;
int32_t custom_thread_creation_options = 5;
int32_t custom_creation_hook_called = 0;
int32_t custom_join_hook_called = 0;
OrtCustomThreadHandle CreateThreadCustomized(void* options, OrtThreadWorkerFn work_loop, void* param) {
if (*((int32_t*)options) == 5) {
custom_creation_hook_called += 1;
}
threads.push_back(std::thread(work_loop, param));
return reinterpret_cast<OrtCustomThreadHandle>(threads.back().native_handle());
}
void JoinThreadCustomized(OrtCustomThreadHandle handle) {
for (auto& t : threads) {
if (reinterpret_cast<OrtCustomThreadHandle>(t.native_handle()) == handle) {
custom_join_hook_called += 1;
t.join();
}
}
}
TEST(CApiTest, TestPerSessionCustomThreadPoolHooks) {
const int32_t thread_count = 3;
Ort::SessionOptions session_options;
// test both intra and inter op thread pool
session_options.SetExecutionMode(ExecutionMode::ORT_PARALLEL);
session_options.SetIntraOpNumThreads(thread_count);
session_options.SetInterOpNumThreads(thread_count);
session_options.SetCustomCreateThreadFn(CreateThreadCustomized);
session_options.SetCustomThreadCreationOptions(&custom_thread_creation_options);
session_options.SetCustomJoinThreadFn(JoinThreadCustomized);
{
Ort::Session session(*ort_env, MODEL_URI, session_options);
}
ASSERT_TRUE(custom_creation_hook_called == (thread_count - 1) << 1);
ASSERT_TRUE(custom_join_hook_called == (thread_count - 1) << 1);
}
} // namespace TestPerSessionCustomThreadHooks
#endif