diff --git a/docs/NotesOnThreading.md b/docs/NotesOnThreading.md new file mode 100644 index 0000000000..df148ff4eb --- /dev/null +++ b/docs/NotesOnThreading.md @@ -0,0 +1,18 @@ +# Notes on Threading in ORT + +This document is intended for ORT developers. + +ORT allows the usage of either OpenMP or non-OpenMP (ORT) threads for execution. Threadpool management +is abstracted behind: (1) ThreadPool class in threadpool.h and (2) functions in thread_utils.h. + +When developing an op, please use these abstractions to parallelize your code. These abstractions centralize 2 things. +When OpenMP is enabled, they resort to using OpenMP. When OpenMP is disabled they resort to sequential execution if the threadpool ptr is NULL or schedule the tasks on the threadpool otherwise. + +Examples of these abstractions are: (threadpool.h has more documentation for these) +* TryBatchParallelFor +* TryParallelFor +* static version of NumThreads + +**Please do not write #ifdef pragma omp in operator code**. + +For intra op parallelism ORT users can use either OpenMP or ORT threadpool. The choice of using OpenMP is indicated by building ORT with ```--use_openmp``` switch. For inter op parallelism, however, we always use the ORT threadpool. diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h index 2e9ae6cdf4..6cd9d81df3 100644 --- a/include/onnxruntime/core/platform/threadpool.h +++ b/include/onnxruntime/core/platform/threadpool.h @@ -158,14 +158,12 @@ class ThreadPool { const std::function& fn); static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, double cost_per_unit, const std::function& fn) { - if (tp == nullptr) { - fn(0, total); - return; - } - tp->ParallelFor(total, cost_per_unit, fn); + TryParallelFor(tp, total, TensorOpCost{0, 0, static_cast(cost_per_unit)}, fn); } + void ParallelFor(std::ptrdiff_t total, const TensorOpCost& cost_per_unit, const std::function& fn); + static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit, const std::function& fn) { if (tp == nullptr) { @@ -174,10 +172,12 @@ class ThreadPool { } tp->ParallelFor(total, cost_per_unit, fn); } + // Similar to ParallelFor above, but takes the specified scheduling strategy // into account. - void ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params, - const std::function& fn); + void + ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params, + const std::function& fn); static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const SchedulingParams& scheduling_params, const std::function& fn) { @@ -187,7 +187,12 @@ class ThreadPool { } tp->ParallelFor(total, scheduling_params, fn); } - // Returns the number of threads in the pool. + + // Prefer using this API to get the number of threads unless you know what you're doing. + // This API takes into account if openmp is enabled/disabled and if the thread pool ptr is nullptr. + static int NumThreads(const concurrency::ThreadPool* tp); + + // Returns the number of threads in the pool. Preferably use the static version of this API instead. int NumThreads() const; // Returns current thread id between 0 and NumThreads() - 1, if called from a diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index b9a70c5d74..7afeb65508 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -298,6 +298,15 @@ void ThreadPool::ParallelFor(std::ptrdiff_t total, double cost_per_unit, ParallelFor(total, TensorOpCost{0, 0, static_cast(cost_per_unit)}, fn); } +int ThreadPool::NumThreads(const concurrency::ThreadPool* tp) { +#ifdef _OPENMP + ORT_UNUSED_PARAMETER(tp); + return (omp_get_num_threads() == 1) ? omp_get_max_threads() : 1; +#else + return tp ? tp->NumThreads() : 1; +#endif +} + int ThreadPool::NumThreads() const { return underlying_threadpool_->NumThreads(); } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 5128c7a7f0..e09f4f4531 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -97,6 +97,10 @@ Abstract: // // Select the threading model. // +// N.B. MLAS_NO_ONNXRUNTIME_THREADPOOL is used to build MLAS test code outside +// of the ONNX Runtime source tree. OpenMP may or may not be enabled in this +// configuration. +// #if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL) #include "core/platform/threadpool.h" @@ -666,19 +670,17 @@ MlasGetMaximumThreadCount( MLAS_THREADPOOL* ThreadPool ) { -#ifdef MLAS_NO_ONNXRUNTIME_THREADPOOL +#if defined(MLAS_NO_ONNXRUNTIME_THREADPOOL) MLAS_UNREFERENCED_PARAMETER(ThreadPool); -#else - if (ThreadPool != nullptr) { - return ThreadPool->NumThreads(); - } -#endif #if defined(_OPENMP) return (omp_get_num_threads() == 1) ? omp_get_max_threads() : 1; #else return 1; #endif +#else + return onnxruntime::concurrency::ThreadPool::NumThreads(ThreadPool); +#endif } inline diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 90019b9402..55e0a943d8 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -113,6 +113,9 @@ class Env { virtual int GetNumCpuCores() const = 0; + // This function doesn't support systems with more than 64 logical processors + virtual std::vector GetThreadAffinityMasks() const = 0; + /// \brief Returns the number of micro-seconds since the Unix epoch. virtual uint64_t NowMicros() const { return env_time_->NowMicros(); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 6f244fbf5c..a908aacb6f 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -120,7 +120,7 @@ class PosixThread : public EnvThread { if (s != 0) ORT_THROW("pthread_setaffinity_np failed"); } - #endif +#endif } ~PosixThread() override { @@ -178,6 +178,12 @@ class PosixEnv : public Env { return std::thread::hardware_concurrency(); } + std::vector GetThreadAffinityMasks() const override { + std::vector ret(std::thread::hardware_concurrency() / 2); + std::iota(ret.begin(), ret.end(), 0); + return ret; + } + void SleepForMicroseconds(int64_t micros) const override { while (micros > 0) { timespec sleep_time; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 5e762b7774..73e51cd266 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -141,6 +141,30 @@ class WindowsEnv : public Env { return processorCoreCount; } + std::vector GetThreadAffinityMasks() const override { + auto generate_vector_of_n = [](int n) { + std::vector ret(n); + std::iota(ret.begin(), ret.end(), 0); + return ret; + }; + // Indeed 64 should be enough. However, it's harmless to have a little more. + SYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer[256]; + DWORD returnLength = sizeof(buffer); + if (GetLogicalProcessorInformation(buffer, &returnLength) == FALSE) { + return generate_vector_of_n(std::thread::hardware_concurrency()); + } + std::vector ret; + int count = (int)(returnLength / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION)); + for (int i = 0; i != count; ++i) { + if (buffer[i].Relationship == RelationProcessorCore) { + ret.push_back(buffer[i].ProcessorMask); + } + } + if (ret.empty()) + return generate_vector_of_n(std::thread::hardware_concurrency()); + return ret; + } + static WindowsEnv& Instance() { static WindowsEnv default_env; return default_env; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 1906ec2ea1..4a197a59e0 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -163,7 +163,7 @@ Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): namespace onnxruntime { template -static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, +static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, double cost, onnxruntime::concurrency::ThreadPool* ttp) { // #define NOTHREADS to execute the lambdas directly and in order if you need to do that to debug @@ -174,67 +174,12 @@ static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, std::bind(lambda, i)(); } #else - - // ORT_ENFORCE may and does throw at times from within the tasks that run - // on a thread-pool. Without propagating exceptions the process exits silently - // which will make diagnosing bugs more difficult. - - // \! UGLY - // We have a problem here with the current thread-pool is that it takes std::function - // by value and copies it more than once (even though it is movable). - // - // To report status and exceptions properly it's better to use - // futures and promises but they are not copyable, so we can't come up with a functor - // with a promise member and we are downgrading to C++11 where we can't have captures that moved in. - // - // At the same time promises MUST live in the child thread so if we throw from the main thread - // we don't destroy any promises that are on the main thread stack which children threads may still be using. - // - // The only solution with the current Eigen that comes to mind is to have shared_ptr to with std::promise. - // const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0); - std::vector > futures; - futures.reserve(total_tasks); - - if (ttp != nullptr) { - for (int i = 0, t = 0; i < max; i += step, ++t) { - auto p_ptr = std::make_shared >(); - futures.push_back(p_ptr->get_future()); - ttp->Schedule([p_ptr, lambda, i]() { - try { - lambda(i); - p_ptr->set_value(); - } catch (...) { - p_ptr->set_exception(std::current_exception()); - } - }); + concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [lambda, step](ptrdiff_t first, ptrdiff_t last) { + for (int i = static_cast(first), end = static_cast(last); i < end; ++i) { + lambda(i * step); } - - // We'd like to wait until all of the tasks have finished - // even though one or more have already thrown. We will store - // the first exception and then will re-throw at the end. - std::exception_ptr pending_exception; - for (auto& fut : futures) { - try { - // get() will re-throw any exceptions - // the running task may throw - fut.get(); - } catch (...) { - if (!pending_exception) { - pending_exception = std::current_exception(); - } - } - } - - if (pending_exception) { - std::rethrow_exception(pending_exception); - } - } else { - for (int i = 0; i < max; i += step) { - std::bind(lambda, i)(); - } - } - + }); #endif } @@ -488,8 +433,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { const size_t last_cell_size_per_direction = batch_size * hidden_size_; IAllocatorUniquePtr local_last_cell; - gsl::span last_cell = Y_c ? Y_c->MutableDataAsSpan() : - Allocate(alloc, last_cell_size_per_direction * num_directions_, local_last_cell); + gsl::span last_cell = Y_c ? Y_c->MutableDataAsSpan() : Allocate(alloc, last_cell_size_per_direction * num_directions_, local_last_cell); gsl::span last_cell_1 = last_cell.subspan(0, last_cell_size_per_direction); @@ -501,17 +445,12 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const { recurrent_weights.subspan(hidden_weights_size_per_direction, hidden_weights_size_per_direction); gsl::span bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction); gsl::span peephole_weights_2 = - peephole_weights.empty() ? - peephole_weights : - peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction); + peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction); gsl::span initial_hidden_2 = - initial_hidden.empty() ? - initial_hidden : - initial_hidden.subspan(initial_hidden_size_per_direction, initial_hidden_size_per_direction); + initial_hidden.empty() ? initial_hidden : initial_hidden.subspan(initial_hidden_size_per_direction, initial_hidden_size_per_direction); gsl::span initial_cell_2 = - initial_cell.empty() ? initial_cell : - initial_cell.subspan(initial_cell_size_per_direction, initial_cell_size_per_direction); + initial_cell.empty() ? initial_cell : initial_cell.subspan(initial_cell_size_per_direction, initial_cell_size_per_direction); gsl::span output_2 = output.empty() ? output : output.subspan(per_direction_offset, output_size - per_direction_offset); @@ -861,7 +800,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha, previous_state, previous_state_end, // Ht-1 hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] - hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) + hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) hidden_size_x4, nullptr); DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4); @@ -910,7 +849,8 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, } }; - ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, mlas_tp_); + double cost = max_sequence_length * fused_hidden_rows; // TODO: approximate cost, needs more tuning. + ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, cost, mlas_tp_); } else { span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend(); @@ -935,7 +875,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, // calculate Xt*(W[iofc]^T) + Ht-t*R[iofc] ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha, previous_state, previous_state_end, // Ht-1 hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc] - hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) + hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T) hidden_size_x4, mlas_tp_); span_T_iter batched_output; @@ -1011,7 +951,7 @@ void UniDirectionalLstm::Compute(const gsl::span& inputs_arg, if (output_sequence && direction_ == Direction::kReverse) ReverseSequence(outputs, original_outputs, sequence_lengths, seq_length_, batch_size_, hidden_size_, - num_directions,mlas_tp_); + num_directions, mlas_tp_); } // #define PREVIOUS_BROKEN_VERSION @@ -1141,7 +1081,7 @@ void UniDirectionalLstm::GateComputations( template void UniDirectionalLstm::SetNumThreads() { - int threads = mlas_tp_ == nullptr ? 1 : mlas_tp_->NumThreads(); + int threads = concurrency::ThreadPool::NumThreads(mlas_tp_); if (threads < 1) threads = 1; diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 62f02ef309..8df0e16229 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -141,7 +141,15 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionGraphOptimizationLevel, _In_ OrtSessionOp } ORT_API_STATUS_IMPL(OrtApis::SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads) { +#ifdef _OPENMP + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(intra_op_num_threads); + LOGS_DEFAULT(WARNING) << "Since openmp is enabled in this build, this API cannot be used to configure" + " intra op num threads. Please use the openmp environment variables to control" + " the number of threads."; +#else options->value.intra_op_param.thread_pool_size = intra_op_num_threads; +#endif return nullptr; } diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index ec2751b2e9..c79919eb7f 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -54,12 +54,12 @@ Status Environment::Initialize(std::unique_ptr logging_ if (to.name == nullptr) { to.name = ORT_TSTR("intra-op"); } - intra_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, nullptr); + intra_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP, nullptr); to = tp_options->inter_op_thread_pool_params; if (to.name == nullptr) { to.name = ORT_TSTR("inter-op"); } - inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, nullptr); + inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP, nullptr); } try { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 2c93875c01..147e232cd8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -183,7 +183,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL && to.affinity_vec_len == 0; thread_pool_ = - concurrency::CreateThreadPool(&Env::Default(), to, nullptr); + concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP, nullptr); } if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) { OrtThreadPoolParams to = session_options_.inter_op_param; @@ -194,7 +194,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, if (to.name == nullptr) to.name = ORT_TSTR("intra-op"); inter_op_thread_pool_ = - concurrency::CreateThreadPool(&Env::Default(), to, nullptr); + concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP, nullptr); if (inter_op_thread_pool_ == nullptr) { LOGS(*session_logger_, INFO) << "Failed to create the inter-op thread pool for the parallel executor, setting ExecutionMode to SEQUENTIAL"; session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL; diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index dbb82f3496..1ef5b9ebf8 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -9,59 +9,51 @@ namespace onnxruntime { namespace concurrency { -static inline std::vector GenerateVectorOfN(size_t n) { - std::vector ret(n); - std::iota(ret.begin(), ret.end(), 0); - return ret; - } -#ifdef _WIN32 - // This function doesn't support systems with more than 64 logical processors - static std::vector GetNumCpuCores() { - // Indeed 64 should be enough. However, it's harmless to have a little more. - SYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer[256]; - DWORD returnLength = sizeof(buffer); - if (GetLogicalProcessorInformation(buffer, &returnLength) == FALSE) { - return GenerateVectorOfN(std::thread::hardware_concurrency()); - } - std::vector ret; - int count = (int)(returnLength / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION)); - for (int i = 0; i != count; ++i) { - if (buffer[i].Relationship == RelationProcessorCore) { - ret.push_back(buffer[i].ProcessorMask); - } - } - if (ret.empty()) - return GenerateVectorOfN(std::thread::hardware_concurrency()); - return ret; - } -#else - static std::vector GetNumCpuCores() { - return GenerateVectorOfN(std::thread::hardware_concurrency() / 2); - } -#endif - std::unique_ptr CreateThreadPool(Env* env, OrtThreadPoolParams options, Eigen::Allocator* allocator) { - if (options.thread_pool_size == 1) - return nullptr; - std::vector cpu_list; - ThreadOptions to; - if (options.affinity_vec_len != 0) { - to.affinity.assign(options.affinity_vec, options.affinity_vec + options.affinity_vec_len); - } - if (options.thread_pool_size <= 0) { // default - cpu_list = GetNumCpuCores(); - if (cpu_list.empty() || cpu_list.size() == 1) - return nullptr; - options.thread_pool_size = static_cast(cpu_list.size()); - if (options.auto_set_affinity) - to.affinity = cpu_list; - } +static std::unique_ptr +CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options, Eigen::Allocator* allocator) { + if (options.thread_pool_size == 1) + return nullptr; + std::vector cpu_list; + ThreadOptions to; + if (options.affinity_vec_len != 0) { + to.affinity.assign(options.affinity_vec, options.affinity_vec + options.affinity_vec_len); + } + if (options.thread_pool_size <= 0) { // default + cpu_list = Env::Default().GetThreadAffinityMasks(); + if (cpu_list.empty() || cpu_list.size() == 1) + return nullptr; + options.thread_pool_size = static_cast(cpu_list.size()); + if (options.auto_set_affinity) + to.affinity = cpu_list; + } - return onnxruntime::make_unique(env, to, options.name, options.thread_pool_size, - options.allow_spinning, allocator); - } - } // namespace concurrency + return onnxruntime::make_unique(env, to, options.name, options.thread_pool_size, + options.allow_spinning, allocator); +} + +std::unique_ptr +CreateThreadPool(Env* env, OrtThreadPoolParams options, ThreadPoolType tpool_type, Eigen::Allocator* allocator) { +// If openmp is enabled we don't want to create any additional threadpools for sequential execution. +// However, parallel execution relies on the existence of a separate threadpool. Hence we allow eigen threadpools +// to be created for parallel execution. +#ifdef _OPENMP + ORT_UNUSED_PARAMETER(env); + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(allocator); + if (tpool_type != ThreadPoolType::INTER_OP) { + return nullptr; + } else { + return CreateThreadPoolHelper(env, options, allocator); + } +#else + ORT_UNUSED_PARAMETER(tpool_type); + return CreateThreadPoolHelper(env, options, allocator); +#endif +} + +} // namespace concurrency } // namespace onnxruntime -namespace OrtApis{ +namespace OrtApis { ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out) { *out = new OrtThreadingOptions(); return nullptr; @@ -70,4 +62,4 @@ ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out) ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions* p) { delete p; } -} \ No newline at end of file +} // namespace OrtApis \ No newline at end of file diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index 7d040ce4e1..66e670312a 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -7,7 +7,7 @@ #include #include -struct OrtThreadPoolParams{ +struct OrtThreadPoolParams { //0: Use default setting. (All the physical cores or half of the logical cores) //1: Don't create thread pool //n: Create a thread pool with n threads. @@ -25,7 +25,7 @@ struct OrtThreadPoolParams{ size_t* affinity_vec = nullptr; size_t affinity_vec_len = 0; const ORTCHAR_T* name = nullptr; -} ; +}; struct OrtThreadingOptions { // Params for creating the threads that parallelizes execution of an op @@ -33,13 +33,17 @@ struct OrtThreadingOptions { // Params for creating the threads that parallelizes execution across ops OrtThreadPoolParams inter_op_thread_pool_params; -} ; +}; namespace onnxruntime { namespace concurrency { - +enum class ThreadPoolType : uint8_t { + INTRA_OP, + INTER_OP +}; std::unique_ptr CreateThreadPool(Env* env, OrtThreadPoolParams options, + ThreadPoolType tpool_type, Eigen::Allocator* allocator = nullptr); } // namespace concurrency } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 261540fc5a..2369ebac06 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -166,7 +166,7 @@ class PlannerTest : public ::testing::Test { PlannerTest() : model_("test", false, DefaultLoggingManager().DefaultLogger()), graph_(model_.MainGraph()), - tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams())), + tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams(), concurrency::ThreadPoolType::INTRA_OP)), state_(execution_providers_, false, tp_.get(), nullptr) { std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); in_place_kernel_ = @@ -201,8 +201,8 @@ class PlannerTest : public ::testing::Test { void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) { auto info = onnxruntime::make_unique(*p_node, kernel_def, *execution_providers_.Get(*p_node), - state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(), - state_.GetFuncMgr(), state_.GetDataTransferMgr()); + state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(), + state_.GetFuncMgr(), state_.GetDataTransferMgr()); op_kernel_infos_.push_back(std::move(info)); if (reg->TryFindKernel(*p_node, onnxruntime::kCpuExecutionProvider) == nullptr) { auto st = reg->Register( diff --git a/onnxruntime/test/framework/math_test.cc b/onnxruntime/test/framework/math_test.cc index 9798a523cc..7f071339ac 100644 --- a/onnxruntime/test/framework/math_test.cc +++ b/onnxruntime/test/framework/math_test.cc @@ -28,12 +28,12 @@ namespace onnxruntime { //parameter is thread pool size class MathGemmTest : public testing::TestWithParam { protected: - static OrtThreadPoolParams CreateThreadPoolOptions(int size){ - OrtThreadPoolParams option; - option.thread_pool_size = size; - return option; - } - std::unique_ptr tp{concurrency::CreateThreadPool(&Env::Default(),CreateThreadPoolOptions(GetParam()))}; + static OrtThreadPoolParams CreateThreadPoolOptions(int size) { + OrtThreadPoolParams option; + option.thread_pool_size = size; + return option; + } + std::unique_ptr tp{concurrency::CreateThreadPool(&Env::Default(), CreateThreadPoolOptions(GetParam()), concurrency::ThreadPoolType::INTRA_OP)}; }; TEST_P(MathGemmTest, GemmNoTransNoTrans) { @@ -124,7 +124,7 @@ TEST_P(MathGemmTest, GemmNoTransTrans) { } INSTANTIATE_TEST_SUITE_P(MathGemmTests, MathGemmTest, - testing::Values(1, 0)); + testing::Values(1, 0)); TEST(MathTest, GemvNoTrans) { auto& provider = CPUMathUtil::Instance(); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b9da657320..8f8427b4fd 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -42,7 +42,7 @@ class SessionStateAddGetKernelTest : public testing::TestWithParam {}; TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { OrtThreadPoolParams to; to.thread_pool_size = GetParam(); - auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to); + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); ONNX_OPERATOR_SCHEMA(Variable) .SetDoc("Input variable.") .Output(0, "output_1", "docstr for output_1.", "tensor(int32)"); @@ -96,8 +96,7 @@ class TestParam { bool enable_mem_pattern; int thread_count; }; -TestParam param_list[] = {{3, true, 0}, {4, true, 0}, {3, false, 0}, {4, false, 0}, - {3, true, 1}, {4, true, 1}, {3, false, 1}, {4, false, 1}}; +TestParam param_list[] = {{3, true, 0}, {4, true, 0}, {3, false, 0}, {4, false, 0}, {3, true, 1}, {4, true, 1}, {3, false, 1}, {4, false, 1}}; } // namespace class SessionStateTestP : public testing::TestWithParam {}; // Test that we separate out constant and non-constant initializers correctly @@ -105,7 +104,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { const TestParam& param = GetParam(); OrtThreadPoolParams to; to.thread_pool_size = to.thread_pool_size; - auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to); + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); std::basic_ostringstream oss; oss << ORT_TSTR("testdata/optional_inputs_ir") << param.ir_version << ORT_TSTR(".onnx");