add ThreadingOptions, wraps OrtThreadingOptions (#13711)

…threadpools' options of The Env.

### Description
<!-- Describe your changes. -->
add a c++ class ThreadingOptions, wraps OrtThreadingOptions
as I described in issue #13710 


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

close #13710

Co-authored-by: zengxiangneng <zengxiangneng@360.cn>
This commit is contained in:
we1559 2023-01-07 03:21:10 +08:00 committed by GitHub
parent babc1323e3
commit c65a03699a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 0 deletions

View file

@ -189,6 +189,7 @@ namespace detail {
ORT_DEFINE_RELEASE(Allocator);
ORT_DEFINE_RELEASE(MemoryInfo);
ORT_DEFINE_RELEASE(CustomOpDomain);
ORT_DEFINE_RELEASE(ThreadingOptions);
ORT_DEFINE_RELEASE(Env);
ORT_DEFINE_RELEASE(RunOptions);
ORT_DEFINE_RELEASE(Session);
@ -340,6 +341,36 @@ struct Status : detail::Base<OrtStatus> {
OrtErrorCode GetErrorCode() const;
};
/** \brief The ThreadingOptions
*
* The ThreadingOptions used for set global threadpools' options of The Env.
*/
struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
/// \brief Wraps OrtApi::CreateThreadingOptions
ThreadingOptions();
/// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
/// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
/// \brief Wraps OrtApi::SetGlobalSpinControl
ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
/// \brief Wraps OrtApi::SetGlobalDenormalAsZero
ThreadingOptions& SetGlobalDenormalAsZero();
/// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
/// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
/// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
};
/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.

View file

@ -368,6 +368,45 @@ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial
ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
}
inline ThreadingOptions::ThreadingOptions() {
ThrowOnError(GetApi().CreateThreadingOptions(&p_));
}
inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
return *this;
}
inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
return *this;
}
inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
return *this;
}
inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
return *this;
}
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
return *this;
}
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
return *this;
}
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
return *this;
}
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
if (strcmp(logid, "onnxruntime-node") == 0) {