From 21eb747a0fcad6559c4e601697e2f4fa06e2fb27 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Fri, 12 Nov 2021 19:10:31 -0800 Subject: [PATCH] Custom thread creation and join hooks (#9426) --- .../core/session/onnxruntime_c_api.h | 84 ++++++++++++++++++- .../core/session/onnxruntime_cxx_api.h | 4 + .../core/session/onnxruntime_cxx_inline.h | 15 ++++ onnxruntime/core/framework/session_options.h | 9 ++ onnxruntime/core/platform/env.h | 11 +++ onnxruntime/core/platform/posix/env.cc | 81 +++++++++++------- onnxruntime/core/platform/windows/env.cc | 41 +++++++-- onnxruntime/core/session/inference_session.cc | 16 ++++ onnxruntime/core/session/onnxruntime_c_api.cc | 27 ++++++ onnxruntime/core/session/ort_apis.h | 6 ++ onnxruntime/core/util/thread_utils.cc | 35 ++++++++ onnxruntime/core/util/thread_utils.h | 5 ++ .../test/global_thread_pools/test_main.cc | 50 ++++++++++- onnxruntime/test/shared_lib/test_inference.cc | 47 +++++++++++ 14 files changed, 387 insertions(+), 44 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index aff4f7e6ba..b8bfdfc2e5 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -527,6 +527,29 @@ typedef struct OrtApiBase OrtApiBase; */ ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; +/** \brief Thread work loop function +* +* Onnxruntime will provide the working loop on custom thread creation +* Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn +*/ +typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param); + +typedef const struct OrtCustomHandleType{ char __place_holder; }* OrtCustomThreadHandle; + +/** \brief Ort custom thread creation function +* +* The function should return a thread handle to be used in onnxruntime thread pools +* Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread +*/ +typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param); + +/** \brief Custom thread join function +* +* Onnxruntime thread pool destructor will call the function to join a custom thread. +* Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn +*/ +typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle); + /** \brief The C API * * All C API functions are defined inside this structure as pointers to functions. @@ -3037,7 +3060,6 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); - /// @} /// \name OrtKernelContext @@ -3053,6 +3075,66 @@ struct OrtApi { */ ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); /// @} + + /// \name SessionOptions + /// @{ + /** \brief Set custom thread creation function + * + * \param[in] session options + * \param[in] custom thread creation function + * + * * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /** \brief Set creation options for custom thread + * + * \param[in] session options + * \param[in] custom thread creation options (can be nullptr) + * + * * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); + + /** \brief Set custom thread join function + * + * \param[in] session options + * \param[in] custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set + * + * * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); + /// @} + + /// \name OrtThreadingOptions + /// @{ + /** \brief Set custom thread creation function for global thread pools + * + * \param[inout] tp_options + * \param[in] custom thread creation function + * + * * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /** \brief Set custom thread creation options for global thread pools + * + * \param[inout] tp_options + * \param[in] custom thread creation options (can be nullptr) + * + * * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); + + /** \brief Set custom thread join function for global thread pools + * + * \param[inout] tp_options + * \param[in] custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set + * + * * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); + /// @} }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index fc8b1eeeeb..28f6c79745 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -351,6 +351,10 @@ struct SessionOptions : Base { SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + + SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn + SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions + SessionOptions& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn }; /** \brief Wrapper around ::OrtModelMetadata diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 928d207831..dd1a2521a5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -510,6 +510,21 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or return *this; } +inline SessionOptions& SessionOptions::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); + return *this; +} + +inline SessionOptions& SessionOptions::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); + return *this; +} + +inline SessionOptions& SessionOptions::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); + return *this; +} + inline SessionOptions& SessionOptions::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(p_, &provider_options)); return *this; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index ff06d0f7b9..2f17779238 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -119,6 +119,15 @@ struct SessionOptions { // See onnxruntime_c_api.h for detailed documentation. Status AddInitializer(_In_z_ const char* name, _In_ const OrtValue* val) noexcept; + + // custom function callback to create a thread + OrtCustomCreateThreadFn custom_create_thread_fn = nullptr; + + // custom options to pass to custom_create_thread_fn + void* custom_thread_creation_options = nullptr; + + // custom function callback to join a thread + OrtCustomJoinThreadFn custom_join_thread_fn = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 3beec3a807..88ff67ace7 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -28,6 +28,7 @@ limitations under the License. #include "core/framework/callback.h" #include "core/platform/env_time.h" #include "core/platform/telemetry.h" +#include "core/session/onnxruntime_c_api.h" #ifndef _WIN32 #include @@ -49,6 +50,12 @@ using FileOffsetType = off_t; class EnvThread { public: virtual ~EnvThread() = default; + + protected: + OrtCustomCreateThreadFn custom_create_thread_fn = nullptr; + void* custom_thread_creation_options = nullptr; + OrtCustomJoinThreadFn custom_join_thread_fn = nullptr; + OrtCustomThreadHandle custom_thread_handle = nullptr; }; // Parameters that are required to create a set of threads for a thread pool @@ -67,6 +74,10 @@ struct ThreadOptions { // Set or unset denormal as zero. bool set_denormal_as_zero = false; + + OrtCustomCreateThreadFn custom_create_thread_fn = nullptr; + void* custom_thread_creation_options = nullptr; + OrtCustomJoinThreadFn custom_join_thread_fn = nullptr; }; /// \brief An interface used by the onnxruntime implementation to /// access operating system functionality like the filesystem etc. diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index c63997fc04..cf91759e4a 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -143,47 +143,63 @@ class PosixThread : public EnvThread { PosixThread(const ORTCHAR_T* name_prefix, int index, unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { - pthread_attr_t attr; - int s = pthread_attr_init(&attr); - if (s != 0) { - auto[err_no, err_msg] = GetSystemError(); - ORT_THROW("pthread_attr_init failed, error code: ", err_no, " error msg: ", err_msg); - } - if (thread_options.stack_size > 0) { - s = pthread_attr_setstacksize(&attr, thread_options.stack_size); - if (s != 0) { - auto[err_no, err_msg] = GetSystemError(); - ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg); + custom_create_thread_fn = thread_options.custom_create_thread_fn; + custom_thread_creation_options = thread_options.custom_thread_creation_options; + custom_join_thread_fn = thread_options.custom_join_thread_fn; + + if (custom_create_thread_fn) { + custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options}); + if (!custom_thread_handle) { + ORT_THROW("custom_create_thread_fn returned invalid handle."); + } + } else { + pthread_attr_t attr; + int s = pthread_attr_init(&attr); + if (s != 0) { + auto [err_no, err_msg] = GetSystemError(); + ORT_THROW("pthread_attr_init failed, error code: ", err_no, " error msg: ", err_msg); + } + if (thread_options.stack_size > 0) { + s = pthread_attr_setstacksize(&attr, thread_options.stack_size); + if (s != 0) { + auto [err_no, err_msg] = GetSystemError(); + ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg); + } + } + s = pthread_create(&hThread, &attr, ThreadMain, + new Param{name_prefix, index, start_address, param, thread_options}); + if (s != 0) { + auto [err_no, err_msg] = GetSystemError(); + ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg); } - } - s = pthread_create(&hThread, &attr, ThreadMain, - new Param{name_prefix, index, start_address, param, thread_options}); - if (s != 0) { - auto[err_no, err_msg] = GetSystemError(); - ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg); - } #if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) - if (!thread_options.affinity.empty()) { - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(thread_options.affinity[index], &cpuset); - s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset); - if (s != 0) { - auto[err_no, err_msg] = GetSystemError(); - ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg); + if (!thread_options.affinity.empty()) { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(thread_options.affinity[index], &cpuset); + s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset); + if (s != 0) { + auto [err_no, err_msg] = GetSystemError(); + ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg); + } } - } #endif + } } ~PosixThread() override { - void* res; + if (custom_thread_handle) { + custom_join_thread_fn(custom_thread_handle); + custom_thread_handle = nullptr; + } else { + void* res; #ifdef NDEBUG - pthread_join(hThread, &res); + pthread_join(hThread, &res); #else - int ret = pthread_join(hThread, &res); - assert(ret == 0); + int ret = pthread_join(hThread, &res); + assert(ret == 0); #endif + } } private: @@ -198,6 +214,9 @@ class PosixThread : public EnvThread { } return nullptr; } + static void CustomThreadMain(void* param) { + ThreadMain(param); + } pthread_t hThread; }; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 7879711955..1f272aa4c6 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,18 +59,33 @@ class WindowsThread : public EnvThread { public: WindowsThread(const ORTCHAR_T* name_prefix, int index, unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, - const ThreadOptions& thread_options) - : hThread((HANDLE)_beginthreadex(nullptr, thread_options.stack_size, ThreadMain, - new Param{name_prefix, index, start_address, param, thread_options}, 0, - &threadID)) { + const ThreadOptions& thread_options) { + custom_create_thread_fn = thread_options.custom_create_thread_fn; + custom_thread_creation_options = thread_options.custom_thread_creation_options; + custom_join_thread_fn = thread_options.custom_join_thread_fn; + + if (custom_create_thread_fn) { + custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, (OrtThreadWorkerFn)CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options}); + if (!custom_thread_handle) { + ORT_THROW("custom_create_thread_fn returned invalid handle."); + } + } else { + hThread.reset(reinterpret_cast(_beginthreadex(nullptr, thread_options.stack_size, ThreadMain, + new Param{name_prefix, index, start_address, param, thread_options}, 0, + &threadID))); + } } ~WindowsThread() { - DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE); - FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED); + if (custom_thread_handle) { + custom_join_thread_fn(custom_thread_handle); + custom_thread_handle = nullptr; + } else { + DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE); + FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED); + } } - private: typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription); @@ -100,7 +115,6 @@ class WindowsThread : public EnvThread { // Ignore the error (void)pSetThrDesc(GetCurrentThread(), oss.str().c_str()); } - unsigned ret = 0; ORT_TRY { ret = p->start_address(p->index, p->param); @@ -113,6 +127,15 @@ class WindowsThread : public EnvThread { } #pragma warning(pop) + static void __stdcall CustomThreadMain(void* param) { + std::unique_ptr p((Param*)param); + ORT_TRY { + p->start_address(p->index, p->param); + } + ORT_CATCH(const std::exception&) { + p->param->Cancel(); + } + } unsigned threadID = 0; wil::unique_handle hThread; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 0bd5349615..31c09e925e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -296,6 +296,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL && to.affinity_vec_len == 0; to.allow_spinning = allow_intra_op_spinning; + + // Set custom threading functions + to.custom_create_thread_fn = session_options_.custom_create_thread_fn; + to.custom_thread_creation_options = session_options.custom_thread_creation_options; + to.custom_join_thread_fn = session_options_.custom_join_thread_fn; + if (to.custom_create_thread_fn) { + ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for intra op thread pool"); + } thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); } @@ -316,6 +324,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, to.name = inter_thread_pool_name_.c_str(); to.set_denormal_as_zero = set_denormal_as_zero; to.allow_spinning = allow_inter_op_spinning; + + // Set custom threading functions + to.custom_create_thread_fn = session_options_.custom_create_thread_fn; + to.custom_thread_creation_options = session_options.custom_thread_creation_options; + to.custom_join_thread_fn = session_options_.custom_join_thread_fn; + if (to.custom_create_thread_fn) { + ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for inter op thread pool"); + } inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP); if (inter_op_thread_pool_ == nullptr) { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 56e49efe30..be11a37c4e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2082,6 +2082,27 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + API_IMPL_BEGIN + options->value.custom_create_thread_fn = ort_custom_create_thread_fn; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options) { + API_IMPL_BEGIN + options->value.custom_thread_creation_options = ort_custom_thread_creation_options; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + API_IMPL_BEGIN + options->value.custom_join_thread_fn = ort_custom_join_thread_fn; + return nullptr; + API_IMPL_END +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString, @@ -2358,6 +2379,12 @@ static constexpr OrtApi ort_api_1_to_10 = { // Version 10 - In development, feel free to add/remove/rearrange here &OrtApis::HasValue, &OrtApis::KernelContext_GetGPUComputeStream, + &OrtApis::SessionOptionsSetCustomCreateThreadFn, + &OrtApis::SessionOptionsSetCustomThreadCreationOptions, + &OrtApis::SessionOptionsSetCustomJoinThreadFn, + &OrtApis::SetGlobalCustomCreateThreadFn, + &OrtApis::SetGlobalCustomThreadCreationOptions, + &OrtApis::SetGlobalCustomJoinThreadFn, }; // Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index bf9f023dc3..561870d235 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -317,5 +317,11 @@ ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outp ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); ORT_API_STATUS_IMPL(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); +ORT_API_STATUS_IMPL(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); +ORT_API_STATUS_IMPL(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); +ORT_API_STATUS_IMPL(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); +ORT_API_STATUS_IMPL(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); +ORT_API_STATUS_IMPL(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); +ORT_API_STATUS_IMPL(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); } // namespace OrtApis diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index 372b933a25..ab87422cd6 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -28,6 +28,14 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { } to.set_denormal_as_zero = options.set_denormal_as_zero; + // set custom thread management members + to.custom_create_thread_fn = options.custom_create_thread_fn; + to.custom_thread_creation_options = options.custom_thread_creation_options; + to.custom_join_thread_fn = options.custom_join_thread_fn; + if (to.custom_create_thread_fn) { + ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set"); + } + return std::make_unique(env, to, options.name, options.thread_pool_size, options.allow_spinning); } @@ -99,4 +107,31 @@ ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_opt return nullptr; } +ORT_API_STATUS_IMPL(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + if (!tp_options) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions"); + } + tp_options->inter_op_thread_pool_params.custom_create_thread_fn = ort_custom_create_thread_fn; + tp_options->intra_op_thread_pool_params.custom_create_thread_fn = ort_custom_create_thread_fn; + return nullptr; +} + +ORT_API_STATUS_IMPL(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options) { + if (!tp_options) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions"); + } + tp_options->inter_op_thread_pool_params.custom_thread_creation_options = ort_custom_thread_creation_options; + tp_options->intra_op_thread_pool_params.custom_thread_creation_options = ort_custom_thread_creation_options; + return nullptr; +} + +ORT_API_STATUS_IMPL(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + if (!tp_options) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Received null OrtThreadingOptions"); + } + tp_options->inter_op_thread_pool_params.custom_join_thread_fn = ort_custom_join_thread_fn; + tp_options->intra_op_thread_pool_params.custom_join_thread_fn = ort_custom_join_thread_fn; + return nullptr; +} + } // namespace OrtApis diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index d8edd896d8..2780caae82 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -28,6 +28,11 @@ struct OrtThreadPoolParams { // Set or unset denormal as zero bool set_denormal_as_zero = false; + + // members to manage custom threads + OrtCustomCreateThreadFn custom_create_thread_fn = nullptr; + void* custom_thread_creation_options = nullptr; + OrtCustomJoinThreadFn custom_join_thread_fn = nullptr; }; struct OrtThreadingOptions { diff --git a/onnxruntime/test/global_thread_pools/test_main.cc b/onnxruntime/test/global_thread_pools/test_main.cc index a0c6f162e9..6b39abfa55 100644 --- a/onnxruntime/test/global_thread_pools/test_main.cc +++ b/onnxruntime/test/global_thread_pools/test_main.cc @@ -26,13 +26,42 @@ std::unique_ptr ort_env; return -1; \ } +namespace TestGlobalCustomThreadHooks { + +std::vector threads; +int32_t custom_thread_creation_options = 5; +int32_t custom_creation_hook_called = 0; +int32_t custom_join_hook_called = 0; + +OrtCustomThreadHandle CreateThreadCustomized(void* options, OrtThreadWorkerFn work_loop, void* param) { + if (*((int32_t*)options) == 5) { + custom_creation_hook_called += 1; + } + threads.push_back(std::thread(work_loop, param)); + return reinterpret_cast(threads.back().native_handle()); +} + +void JoinThreadCustomized(OrtCustomThreadHandle handle) { + for (auto& t : threads) { + if (reinterpret_cast(t.native_handle()) == handle) { + custom_join_hook_called += 1; + t.join(); + } + } +} + +} // namespace TestGlobalCustomThreadHooks + +using namespace TestGlobalCustomThreadHooks; + int main(int argc, char** argv) { int status = 0; + const int thread_pool_size = std::thread::hardware_concurrency(); ORT_TRY { ::testing::InitGoogleTest(&argc, argv); const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - OrtThreadingOptions* tp_options; std::unique_ptr st_ptr(nullptr, g_ort->ReleaseStatus); + OrtThreadingOptions* tp_options; st_ptr.reset(g_ort->CreateThreadingOptions(&tp_options)); ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); @@ -40,10 +69,19 @@ int main(int argc, char** argv) { st_ptr.reset(g_ort->SetGlobalSpinControl(tp_options, 0)); ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); - st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, std::thread::hardware_concurrency())); + st_ptr.reset(g_ort->SetGlobalIntraOpNumThreads(tp_options, thread_pool_size)); ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); - st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, std::thread::hardware_concurrency())); + st_ptr.reset(g_ort->SetGlobalCustomCreateThreadFn(tp_options, CreateThreadCustomized)); + ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); + + st_ptr.reset(g_ort->SetGlobalCustomThreadCreationOptions(tp_options, &custom_thread_creation_options)); + ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); + + st_ptr.reset(g_ort->SetGlobalCustomJoinThreadFn(tp_options, JoinThreadCustomized)); + ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); + + st_ptr.reset(g_ort->SetGlobalInterOpNumThreads(tp_options, thread_pool_size)); ORT_RETURN_IF_NON_NULL_STATUS(st_ptr); st_ptr.reset(g_ort->SetGlobalDenormalAsZero(tp_options)); @@ -63,6 +101,12 @@ int main(int argc, char** argv) { //TODO: Fix the C API issue ort_env.reset(); //If we don't do this, it will crash +#ifndef _OPENMP + const int expexted_custom_calls = (thread_pool_size - 1) << 1; + ORT_ENFORCE(custom_creation_hook_called == expexted_custom_calls, "custom thread creation function was not called as expected"); + ORT_ENFORCE(custom_join_hook_called == expexted_custom_calls, "custom thread join function was not called as expected"); +#endif + #ifndef USE_ONNXRUNTIME_DLL //make memory leak checker happy ::google::protobuf::ShutdownProtobufLibrary(); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 337ab29aa4..3aca885ea2 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -18,6 +19,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/util/thread_utils.h" #include "providers.h" #include "test_allocator.h" #include "test_fixture.h" @@ -1836,3 +1838,48 @@ TEST(CApiTest, TestConfigureTensorRTProviderOptions) { ASSERT_TRUE(stat(engine_cache_path, &buffer) == 0); } #endif + +#ifndef _OPENMP +namespace TestPerSessionCustomThreadHooks { + +std::vector threads; +int32_t custom_thread_creation_options = 5; +int32_t custom_creation_hook_called = 0; +int32_t custom_join_hook_called = 0; + +OrtCustomThreadHandle CreateThreadCustomized(void* options, OrtThreadWorkerFn work_loop, void* param) { + if (*((int32_t*)options) == 5) { + custom_creation_hook_called += 1; + } + threads.push_back(std::thread(work_loop, param)); + return reinterpret_cast(threads.back().native_handle()); +} + +void JoinThreadCustomized(OrtCustomThreadHandle handle) { + for (auto& t : threads) { + if (reinterpret_cast(t.native_handle()) == handle) { + custom_join_hook_called += 1; + t.join(); + } + } +} + +TEST(CApiTest, TestPerSessionCustomThreadPoolHooks) { + const int32_t thread_count = 3; + Ort::SessionOptions session_options; + // test both intra and inter op thread pool + session_options.SetExecutionMode(ExecutionMode::ORT_PARALLEL); + session_options.SetIntraOpNumThreads(thread_count); + session_options.SetInterOpNumThreads(thread_count); + session_options.SetCustomCreateThreadFn(CreateThreadCustomized); + session_options.SetCustomThreadCreationOptions(&custom_thread_creation_options); + session_options.SetCustomJoinThreadFn(JoinThreadCustomized); + { + Ort::Session session(*ort_env, MODEL_URI, session_options); + } + ASSERT_TRUE(custom_creation_hook_called == (thread_count - 1) << 1); + ASSERT_TRUE(custom_join_hook_called == (thread_count - 1) << 1); +} + +} // namespace TestPerSessionCustomThreadHooks +#endif