Threadpool related changes. (#3564)

Threadpool related changes.

Don't create ORT threadpool if openmp is enabled (except for inter op threadpool).
Created a new static function ThreadPool::NumThreads to account for openmp settings and null threadpool ptr.
Log a warning when using SetIntraOpNumThreads when openmp is enabled.
Added a document for ORT devs.
Fix LSTM to use the new threadpool abstractions.
Rename GetNumCpuCores to GetThreadAffinityMasks and move it to the Env class.

Co-authored-by: Tracy Sharpe <tracysh@microsoft.com>
This commit is contained in:
Pranav Sharma 2020-04-21 09:57:39 -07:00 committed by GitHub
parent 3dd3f84116
commit 9636da3951
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 174 additions and 164 deletions

18
docs/NotesOnThreading.md Normal file
View file

@ -0,0 +1,18 @@
# Notes on Threading in ORT
This document is intended for ORT developers.
ORT allows the usage of either OpenMP or non-OpenMP (ORT) threads for execution. Threadpool management
is abstracted behind: (1) ThreadPool class in threadpool.h and (2) functions in thread_utils.h.
When developing an op, please use these abstractions to parallelize your code. These abstractions centralize 2 things.
When OpenMP is enabled, they resort to using OpenMP. When OpenMP is disabled they resort to sequential execution if the threadpool ptr is NULL or schedule the tasks on the threadpool otherwise.
Examples of these abstractions are: (threadpool.h has more documentation for these)
* TryBatchParallelFor
* TryParallelFor
* static version of NumThreads
**Please do not write #ifdef pragma omp in operator code**.
For intra op parallelism ORT users can use either OpenMP or ORT threadpool. The choice of using OpenMP is indicated by building ORT with ```--use_openmp``` switch. For inter op parallelism, however, we always use the ORT threadpool.

View file

@ -158,14 +158,12 @@ class ThreadPool {
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn);
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, double cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
if (tp == nullptr) {
fn(0, total);
return;
}
tp->ParallelFor(total, cost_per_unit, fn);
TryParallelFor(tp, total, TensorOpCost{0, 0, static_cast<double>(cost_per_unit)}, fn);
}
void ParallelFor(std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t)>& fn);
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
if (tp == nullptr) {
@ -174,10 +172,12 @@ class ThreadPool {
}
tp->ParallelFor(total, cost_per_unit, fn);
}
// Similar to ParallelFor above, but takes the specified scheduling strategy
// into account.
void ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
void
ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const SchedulingParams& scheduling_params,
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
@ -187,7 +187,12 @@ class ThreadPool {
}
tp->ParallelFor(total, scheduling_params, fn);
}
// Returns the number of threads in the pool.
// Prefer using this API to get the number of threads unless you know what you're doing.
// This API takes into account if openmp is enabled/disabled and if the thread pool ptr is nullptr.
static int NumThreads(const concurrency::ThreadPool* tp);
// Returns the number of threads in the pool. Preferably use the static version of this API instead.
int NumThreads() const;
// Returns current thread id between 0 and NumThreads() - 1, if called from a

View file

@ -298,6 +298,15 @@ void ThreadPool::ParallelFor(std::ptrdiff_t total, double cost_per_unit,
ParallelFor(total, TensorOpCost{0, 0, static_cast<double>(cost_per_unit)}, fn);
}
int ThreadPool::NumThreads(const concurrency::ThreadPool* tp) {
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(tp);
return (omp_get_num_threads() == 1) ? omp_get_max_threads() : 1;
#else
return tp ? tp->NumThreads() : 1;
#endif
}
int ThreadPool::NumThreads() const {
return underlying_threadpool_->NumThreads();
}

View file

@ -97,6 +97,10 @@ Abstract:
//
// Select the threading model.
//
// N.B. MLAS_NO_ONNXRUNTIME_THREADPOOL is used to build MLAS test code outside
// of the ONNX Runtime source tree. OpenMP may or may not be enabled in this
// configuration.
//
#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
#include "core/platform/threadpool.h"
@ -666,19 +670,17 @@ MlasGetMaximumThreadCount(
MLAS_THREADPOOL* ThreadPool
)
{
#ifdef MLAS_NO_ONNXRUNTIME_THREADPOOL
#if defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
#else
if (ThreadPool != nullptr) {
return ThreadPool->NumThreads();
}
#endif
#if defined(_OPENMP)
return (omp_get_num_threads() == 1) ? omp_get_max_threads() : 1;
#else
return 1;
#endif
#else
return onnxruntime::concurrency::ThreadPool::NumThreads(ThreadPool);
#endif
}
inline

View file

@ -113,6 +113,9 @@ class Env {
virtual int GetNumCpuCores() const = 0;
// This function doesn't support systems with more than 64 logical processors
virtual std::vector<size_t> GetThreadAffinityMasks() const = 0;
/// \brief Returns the number of micro-seconds since the Unix epoch.
virtual uint64_t NowMicros() const {
return env_time_->NowMicros();

View file

@ -120,7 +120,7 @@ class PosixThread : public EnvThread {
if (s != 0)
ORT_THROW("pthread_setaffinity_np failed");
}
#endif
#endif
}
~PosixThread() override {
@ -178,6 +178,12 @@ class PosixEnv : public Env {
return std::thread::hardware_concurrency();
}
std::vector<size_t> GetThreadAffinityMasks() const override {
std::vector<size_t> ret(std::thread::hardware_concurrency() / 2);
std::iota(ret.begin(), ret.end(), 0);
return ret;
}
void SleepForMicroseconds(int64_t micros) const override {
while (micros > 0) {
timespec sleep_time;

View file

@ -141,6 +141,30 @@ class WindowsEnv : public Env {
return processorCoreCount;
}
std::vector<size_t> GetThreadAffinityMasks() const override {
auto generate_vector_of_n = [](int n) {
std::vector<size_t> ret(n);
std::iota(ret.begin(), ret.end(), 0);
return ret;
};
// Indeed 64 should be enough. However, it's harmless to have a little more.
SYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer[256];
DWORD returnLength = sizeof(buffer);
if (GetLogicalProcessorInformation(buffer, &returnLength) == FALSE) {
return generate_vector_of_n(std::thread::hardware_concurrency());
}
std::vector<size_t> ret;
int count = (int)(returnLength / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION));
for (int i = 0; i != count; ++i) {
if (buffer[i].Relationship == RelationProcessorCore) {
ret.push_back(buffer[i].ProcessorMask);
}
}
if (ret.empty())
return generate_vector_of_n(std::thread::hardware_concurrency());
return ret;
}
static WindowsEnv& Instance() {
static WindowsEnv default_env;
return default_env;

View file

@ -163,7 +163,7 @@ Equations (Default: f=Sigmoid, g=Tanh, h=Tanh):
namespace onnxruntime {
template <typename TLambda>
static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step,
static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, double cost,
onnxruntime::concurrency::ThreadPool* ttp) {
// #define NOTHREADS to execute the lambdas directly and in order if you need to do that to debug
@ -174,67 +174,12 @@ static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step,
std::bind(lambda, i)();
}
#else
// ORT_ENFORCE may and does throw at times from within the tasks that run
// on a thread-pool. Without propagating exceptions the process exits silently
// which will make diagnosing bugs more difficult.
// \! UGLY
// We have a problem here with the current thread-pool is that it takes std::function
// by value and copies it more than once (even though it is movable).
//
// To report status and exceptions properly it's better to use
// futures and promises but they are not copyable, so we can't come up with a functor
// with a promise member and we are downgrading to C++11 where we can't have captures that moved in.
//
// At the same time promises MUST live in the child thread so if we throw from the main thread
// we don't destroy any promises that are on the main thread stack which children threads may still be using.
//
// The only solution with the current Eigen that comes to mind is to have shared_ptr to with std::promise.
//
const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
std::vector<std::future<void> > futures;
futures.reserve(total_tasks);
if (ttp != nullptr) {
for (int i = 0, t = 0; i < max; i += step, ++t) {
auto p_ptr = std::make_shared<std::promise<void> >();
futures.push_back(p_ptr->get_future());
ttp->Schedule([p_ptr, lambda, i]() {
try {
lambda(i);
p_ptr->set_value();
} catch (...) {
p_ptr->set_exception(std::current_exception());
}
});
concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [lambda, step](ptrdiff_t first, ptrdiff_t last) {
for (int i = static_cast<int>(first), end = static_cast<int>(last); i < end; ++i) {
lambda(i * step);
}
// We'd like to wait until all of the tasks have finished
// even though one or more have already thrown. We will store
// the first exception and then will re-throw at the end.
std::exception_ptr pending_exception;
for (auto& fut : futures) {
try {
// get() will re-throw any exceptions
// the running task may throw
fut.get();
} catch (...) {
if (!pending_exception) {
pending_exception = std::current_exception();
}
}
}
if (pending_exception) {
std::rethrow_exception(pending_exception);
}
} else {
for (int i = 0; i < max; i += step) {
std::bind(lambda, i)();
}
}
});
#endif
}
@ -488,8 +433,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
const size_t last_cell_size_per_direction = batch_size * hidden_size_;
IAllocatorUniquePtr<T> local_last_cell;
gsl::span<T> last_cell = Y_c ? Y_c->MutableDataAsSpan<T>() :
Allocate(alloc, last_cell_size_per_direction * num_directions_, local_last_cell);
gsl::span<T> last_cell = Y_c ? Y_c->MutableDataAsSpan<T>() : Allocate(alloc, last_cell_size_per_direction * num_directions_, local_last_cell);
gsl::span<T> last_cell_1 = last_cell.subspan(0, last_cell_size_per_direction);
@ -501,17 +445,12 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
recurrent_weights.subspan(hidden_weights_size_per_direction, hidden_weights_size_per_direction);
gsl::span<const T> bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction);
gsl::span<const T> peephole_weights_2 =
peephole_weights.empty() ?
peephole_weights :
peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction);
peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction);
gsl::span<const T> initial_hidden_2 =
initial_hidden.empty() ?
initial_hidden :
initial_hidden.subspan(initial_hidden_size_per_direction, initial_hidden_size_per_direction);
initial_hidden.empty() ? initial_hidden : initial_hidden.subspan(initial_hidden_size_per_direction, initial_hidden_size_per_direction);
gsl::span<const T> initial_cell_2 =
initial_cell.empty() ? initial_cell :
initial_cell.subspan(initial_cell_size_per_direction, initial_cell_size_per_direction);
initial_cell.empty() ? initial_cell : initial_cell.subspan(initial_cell_size_per_direction, initial_cell_size_per_direction);
gsl::span<T> output_2 =
output.empty() ? output : output.subspan(per_direction_offset, output_size - per_direction_offset);
@ -861,7 +800,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha, previous_state,
previous_state_end, // Ht-1
hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, nullptr);
DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
@ -910,7 +849,8 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
}
};
ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, mlas_tp_);
double cost = max_sequence_length * fused_hidden_rows; // TODO: approximate cost, needs more tuning.
ExecuteLambdaInParallel(hidden_gemm_and_activations, batch_size_, fused_hidden_rows, cost, mlas_tp_);
} else {
span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend();
@ -935,7 +875,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
// calculate Xt*(W[iofc]^T) + Ht-t*R[iofc]
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha, previous_state, previous_state_end, // Ht-1
hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, mlas_tp_);
span_T_iter batched_output;
@ -1011,7 +951,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
if (output_sequence && direction_ == Direction::kReverse)
ReverseSequence<T>(outputs, original_outputs, sequence_lengths, seq_length_, batch_size_, hidden_size_,
num_directions,mlas_tp_);
num_directions, mlas_tp_);
}
// #define PREVIOUS_BROKEN_VERSION
@ -1141,7 +1081,7 @@ void UniDirectionalLstm<T>::GateComputations(
template <typename T>
void UniDirectionalLstm<T>::SetNumThreads() {
int threads = mlas_tp_ == nullptr ? 1 : mlas_tp_->NumThreads();
int threads = concurrency::ThreadPool::NumThreads(mlas_tp_);
if (threads < 1)
threads = 1;

View file

@ -141,7 +141,15 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionGraphOptimizationLevel, _In_ OrtSessionOp
}
ORT_API_STATUS_IMPL(OrtApis::SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads) {
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(intra_op_num_threads);
LOGS_DEFAULT(WARNING) << "Since openmp is enabled in this build, this API cannot be used to configure"
" intra op num threads. Please use the openmp environment variables to control"
" the number of threads.";
#else
options->value.intra_op_param.thread_pool_size = intra_op_num_threads;
#endif
return nullptr;
}

View file

@ -54,12 +54,12 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
if (to.name == nullptr) {
to.name = ORT_TSTR("intra-op");
}
intra_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, nullptr);
intra_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP, nullptr);
to = tp_options->inter_op_thread_pool_params;
if (to.name == nullptr) {
to.name = ORT_TSTR("inter-op");
}
inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, nullptr);
inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP, nullptr);
}
try {

View file

@ -183,7 +183,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL &&
to.affinity_vec_len == 0;
thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, nullptr);
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP, nullptr);
}
if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) {
OrtThreadPoolParams to = session_options_.inter_op_param;
@ -194,7 +194,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
if (to.name == nullptr)
to.name = ORT_TSTR("intra-op");
inter_op_thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, nullptr);
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP, nullptr);
if (inter_op_thread_pool_ == nullptr) {
LOGS(*session_logger_, INFO) << "Failed to create the inter-op thread pool for the parallel executor, setting ExecutionMode to SEQUENTIAL";
session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL;

View file

@ -9,59 +9,51 @@
namespace onnxruntime {
namespace concurrency {
static inline std::vector<size_t> GenerateVectorOfN(size_t n) {
std::vector<size_t> ret(n);
std::iota(ret.begin(), ret.end(), 0);
return ret;
}
#ifdef _WIN32
// This function doesn't support systems with more than 64 logical processors
static std::vector<size_t> GetNumCpuCores() {
// Indeed 64 should be enough. However, it's harmless to have a little more.
SYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer[256];
DWORD returnLength = sizeof(buffer);
if (GetLogicalProcessorInformation(buffer, &returnLength) == FALSE) {
return GenerateVectorOfN(std::thread::hardware_concurrency());
}
std::vector<size_t> ret;
int count = (int)(returnLength / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION));
for (int i = 0; i != count; ++i) {
if (buffer[i].Relationship == RelationProcessorCore) {
ret.push_back(buffer[i].ProcessorMask);
}
}
if (ret.empty())
return GenerateVectorOfN(std::thread::hardware_concurrency());
return ret;
}
#else
static std::vector<size_t> GetNumCpuCores() {
return GenerateVectorOfN(std::thread::hardware_concurrency() / 2);
}
#endif
std::unique_ptr<ThreadPool> CreateThreadPool(Env* env, OrtThreadPoolParams options, Eigen::Allocator* allocator) {
if (options.thread_pool_size == 1)
return nullptr;
std::vector<size_t> cpu_list;
ThreadOptions to;
if (options.affinity_vec_len != 0) {
to.affinity.assign(options.affinity_vec, options.affinity_vec + options.affinity_vec_len);
}
if (options.thread_pool_size <= 0) { // default
cpu_list = GetNumCpuCores();
if (cpu_list.empty() || cpu_list.size() == 1)
return nullptr;
options.thread_pool_size = static_cast<int>(cpu_list.size());
if (options.auto_set_affinity)
to.affinity = cpu_list;
}
static std::unique_ptr<ThreadPool>
CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options, Eigen::Allocator* allocator) {
if (options.thread_pool_size == 1)
return nullptr;
std::vector<size_t> cpu_list;
ThreadOptions to;
if (options.affinity_vec_len != 0) {
to.affinity.assign(options.affinity_vec, options.affinity_vec + options.affinity_vec_len);
}
if (options.thread_pool_size <= 0) { // default
cpu_list = Env::Default().GetThreadAffinityMasks();
if (cpu_list.empty() || cpu_list.size() == 1)
return nullptr;
options.thread_pool_size = static_cast<int>(cpu_list.size());
if (options.auto_set_affinity)
to.affinity = cpu_list;
}
return onnxruntime::make_unique<ThreadPool>(env, to, options.name, options.thread_pool_size,
options.allow_spinning, allocator);
}
} // namespace concurrency
return onnxruntime::make_unique<ThreadPool>(env, to, options.name, options.thread_pool_size,
options.allow_spinning, allocator);
}
std::unique_ptr<ThreadPool>
CreateThreadPool(Env* env, OrtThreadPoolParams options, ThreadPoolType tpool_type, Eigen::Allocator* allocator) {
// If openmp is enabled we don't want to create any additional threadpools for sequential execution.
// However, parallel execution relies on the existence of a separate threadpool. Hence we allow eigen threadpools
// to be created for parallel execution.
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(env);
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(allocator);
if (tpool_type != ThreadPoolType::INTER_OP) {
return nullptr;
} else {
return CreateThreadPoolHelper(env, options, allocator);
}
#else
ORT_UNUSED_PARAMETER(tpool_type);
return CreateThreadPoolHelper(env, options, allocator);
#endif
}
} // namespace concurrency
} // namespace onnxruntime
namespace OrtApis{
namespace OrtApis {
ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out) {
*out = new OrtThreadingOptions();
return nullptr;
@ -70,4 +62,4 @@ ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out)
ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions* p) {
delete p;
}
}
} // namespace OrtApis

View file

@ -7,7 +7,7 @@
#include <memory>
#include <string>
struct OrtThreadPoolParams{
struct OrtThreadPoolParams {
//0: Use default setting. (All the physical cores or half of the logical cores)
//1: Don't create thread pool
//n: Create a thread pool with n threads.
@ -25,7 +25,7 @@ struct OrtThreadPoolParams{
size_t* affinity_vec = nullptr;
size_t affinity_vec_len = 0;
const ORTCHAR_T* name = nullptr;
} ;
};
struct OrtThreadingOptions {
// Params for creating the threads that parallelizes execution of an op
@ -33,13 +33,17 @@ struct OrtThreadingOptions {
// Params for creating the threads that parallelizes execution across ops
OrtThreadPoolParams inter_op_thread_pool_params;
} ;
};
namespace onnxruntime {
namespace concurrency {
enum class ThreadPoolType : uint8_t {
INTRA_OP,
INTER_OP
};
std::unique_ptr<ThreadPool> CreateThreadPool(Env* env, OrtThreadPoolParams options,
ThreadPoolType tpool_type,
Eigen::Allocator* allocator = nullptr);
} // namespace concurrency
} // namespace onnxruntime

View file

@ -166,7 +166,7 @@ class PlannerTest : public ::testing::Test {
PlannerTest()
: model_("test", false, DefaultLoggingManager().DefaultLogger()),
graph_(model_.MainGraph()),
tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams())),
tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams(), concurrency::ThreadPoolType::INTRA_OP)),
state_(execution_providers_, false, tp_.get(), nullptr) {
std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
in_place_kernel_ =
@ -201,8 +201,8 @@ class PlannerTest : public ::testing::Test {
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) {
auto info = onnxruntime::make_unique<OpKernelInfo>(*p_node, kernel_def, *execution_providers_.Get(*p_node),
state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(),
state_.GetFuncMgr(), state_.GetDataTransferMgr());
state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(),
state_.GetFuncMgr(), state_.GetDataTransferMgr());
op_kernel_infos_.push_back(std::move(info));
if (reg->TryFindKernel(*p_node, onnxruntime::kCpuExecutionProvider) == nullptr) {
auto st = reg->Register(

View file

@ -28,12 +28,12 @@ namespace onnxruntime {
//parameter is thread pool size
class MathGemmTest : public testing::TestWithParam<int> {
protected:
static OrtThreadPoolParams CreateThreadPoolOptions(int size){
OrtThreadPoolParams option;
option.thread_pool_size = size;
return option;
}
std::unique_ptr<concurrency::ThreadPool> tp{concurrency::CreateThreadPool(&Env::Default(),CreateThreadPoolOptions(GetParam()))};
static OrtThreadPoolParams CreateThreadPoolOptions(int size) {
OrtThreadPoolParams option;
option.thread_pool_size = size;
return option;
}
std::unique_ptr<concurrency::ThreadPool> tp{concurrency::CreateThreadPool(&Env::Default(), CreateThreadPoolOptions(GetParam()), concurrency::ThreadPoolType::INTRA_OP)};
};
TEST_P(MathGemmTest, GemmNoTransNoTrans) {
@ -124,7 +124,7 @@ TEST_P(MathGemmTest, GemmNoTransTrans) {
}
INSTANTIATE_TEST_SUITE_P(MathGemmTests, MathGemmTest,
testing::Values(1, 0));
testing::Values(1, 0));
TEST(MathTest, GemvNoTrans) {
auto& provider = CPUMathUtil::Instance();

View file

@ -42,7 +42,7 @@ class SessionStateAddGetKernelTest : public testing::TestWithParam<int> {};
TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
OrtThreadPoolParams to;
to.thread_pool_size = GetParam();
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to);
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
ONNX_OPERATOR_SCHEMA(Variable)
.SetDoc("Input variable.")
.Output(0, "output_1", "docstr for output_1.", "tensor(int32)");
@ -96,8 +96,7 @@ class TestParam {
bool enable_mem_pattern;
int thread_count;
};
TestParam param_list[] = {{3, true, 0}, {4, true, 0}, {3, false, 0}, {4, false, 0},
{3, true, 1}, {4, true, 1}, {3, false, 1}, {4, false, 1}};
TestParam param_list[] = {{3, true, 0}, {4, true, 0}, {3, false, 0}, {4, false, 0}, {3, true, 1}, {4, true, 1}, {3, false, 1}, {4, false, 1}};
} // namespace
class SessionStateTestP : public testing::TestWithParam<TestParam> {};
// Test that we separate out constant and non-constant initializers correctly
@ -105,7 +104,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
const TestParam& param = GetParam();
OrtThreadPoolParams to;
to.thread_pool_size = to.thread_pool_size;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to);
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
std::basic_ostringstream<ORTCHAR_T> oss;
oss << ORT_TSTR("testdata/optional_inputs_ir") << param.ir_version << ORT_TSTR(".onnx");