Allow saving on CPU usage for infrequent inference requests by reducing thread spinning (#11841)

Introduce Start/Stop threadpool spinning switch
Add a session config option to force spinning to stop at the end of Run()
This commit is contained in:
Dmitri Smirnov 2022-06-23 10:04:37 -07:00 committed by GitHub
parent c398ad513f
commit 607b7df060
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 644 additions and 559 deletions

File diff suppressed because it is too large Load diff

View file

@ -221,6 +221,14 @@ class ThreadPool {
"Per-thread state should be trivially destructible");
};
// Spinning control API: lets callers switch thread spinning on and off.
// This supports real-time scenarios where spinning between relatively
// infrequent requests contributes to high CPU usage while nothing is
// being processed.
// Both calls forward to the extended Eigen thread pool when one is present
// and do nothing otherwise.
void EnableSpinning();
void DisableSpinning();
// Schedules fn() for execution in the pool of threads. The function may run
// synchronously if it cannot be enqueued. This will occur if the thread pool's
// degree-of-parallelism is 1, but it may also occur for implementation-dependent

View file

@ -114,6 +114,13 @@ static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "e
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
// This option decreases CPU usage between infrequent requests by forcing
// thread-pool (TP) threads to stop spinning as soon as the last of the
// concurrent Run() calls returns.
// Spinning is restarted on the next Run() call.
// Applies only to internal (per-session) thread pools.
// "1": spinning is stopped between runs.
// "0": spinning is left unchanged between runs. The default.
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
// "1": all inconsistencies encountered during shape and type inference
// will result in failures.
// "0": in some cases warnings will be logged but processing will continue. The default.

View file

@ -657,6 +657,18 @@ std::string ThreadPool::StopProfiling(concurrency::ThreadPool* tp) {
}
}
// Forwards the enable-spinning request to extended_eigen_threadpool_.
// Does nothing when extended_eigen_threadpool_ is not set.
void ThreadPool::EnableSpinning() {
  if (!extended_eigen_threadpool_) {
    return;
  }
  extended_eigen_threadpool_->EnableSpinning();
}
// Forwards the disable-spinning request to extended_eigen_threadpool_.
// Does nothing when extended_eigen_threadpool_ is not set.
void ThreadPool::DisableSpinning() {
  if (!extended_eigen_threadpool_) {
    return;
  }
  extended_eigen_threadpool_->DisableSpinning();
}
// Return the number of threads created by the pool.
int ThreadPool::NumThreads() const {
if (underlying_threadpool_) {

View file

@ -264,6 +264,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
}
use_per_session_threads_ = session_options.use_per_session_threads;
force_spinning_stop_between_runs_ = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigForceSpinningStop, "0") == "1";
if (use_per_session_threads_) {
LOGS(*session_logger_, INFO) << "Creating and using per session threadpools since use_per_session_threads_ is true";
@ -1835,6 +1836,31 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
}
#endif
namespace {
// Scope guard that counts concurrent runs and, when thread pools are supplied,
// switches their spinning on and off around a group of overlapping runs.
//
// Construction increments the shared counter; the guard that moves it from
// 0 to 1 calls EnableSpinning() on every non-null pool. Destruction decrements
// the counter; the guard that moves it from 1 to 0 calls DisableSpinning() on
// every non-null pool. With both pools null the guard is a pure run counter.
//
// NOTE(review): the counter update and the Enable/Disable calls are separate
// steps, so a run that starts while the last run is finishing can interleave
// with it. Confirm that this ordering is acceptable for the thread pools.
struct ThreadPoolSpinningSwitch {
  concurrency::ThreadPool* intra_tp_{nullptr};  // may be null: no spin control
  concurrency::ThreadPool* inter_tp_{nullptr};  // may be null: no spin control
  std::atomic<int>& concurrent_num_runs_;       // shared counter, not owned

  // Registers a new run and enables spinning if it is the first active one.
  //   intra_tp / inter_tp: pools to control, or nullptr to leave untouched.
  //   ref: counter shared by all concurrent runs; must outlive this guard.
  ThreadPoolSpinningSwitch(concurrency::ThreadPool* intra_tp,
                           concurrency::ThreadPool* inter_tp,
                           std::atomic<int>& ref) noexcept
      : intra_tp_(intra_tp), inter_tp_(inter_tp), concurrent_num_runs_(ref) {
    // fetch_add returns the previous value: 0 means no other run was active.
    if (concurrent_num_runs_.fetch_add(1, std::memory_order_relaxed) == 0) {
      if (intra_tp_) intra_tp_->EnableSpinning();
      if (inter_tp_) inter_tp_->EnableSpinning();
    }
  }

  // A copied guard would decrement the counter a second time for a single
  // increment, so copying (and therefore moving) is disallowed.
  ThreadPoolSpinningSwitch(const ThreadPoolSpinningSwitch&) = delete;
  ThreadPoolSpinningSwitch& operator=(const ThreadPoolSpinningSwitch&) = delete;

  // Unregisters the run and disables spinning if it was the last active one.
  ~ThreadPoolSpinningSwitch() {
    // fetch_sub returns the previous value: 1 means this was the last run.
    if (1 == concurrent_num_runs_.fetch_sub(1, std::memory_order_acq_rel)) {
      if (intra_tp_) intra_tp_->DisableSpinning();
      if (inter_tp_) inter_tp_->DisableSpinning();
    }
  }
};
} // namespace
Status InferenceSession::Run(const RunOptions& run_options,
const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds,
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches,
@ -1852,12 +1878,20 @@ Status InferenceSession::Run(const RunOptions& run_options,
Status retval = Status::OK();
const Env& env = Env::Default();
// Increment/decrement concurrent_num_runs_ and control
// session threads spinning as configured. Do nothing for graph replay except the counter.
const bool control_spinning = use_per_session_threads_ &&
force_spinning_stop_between_runs_ &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured();
auto* intra_tp = (control_spinning) ? thread_pool_.get() : nullptr;
auto* inter_tp = (control_spinning) ? inter_op_thread_pool_.get() : nullptr;
ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_);
// Check if this Run() is simply going to be a CUDA Graph replay.
if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Replaying the captured "
<< cached_execution_provider_for_graph_replay_.Type()
<< " CUDA Graph for this model with tag: " << run_options.run_tag;
++current_num_runs_;
ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph());
} else {
std::vector<IExecutionProvider*> exec_providers_to_stop;
@ -1902,8 +1936,6 @@ Status InferenceSession::Run(const RunOptions& run_options,
LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag;
}
++current_num_runs_;
// scope of owned_run_logger is just the call to Execute.
// If Execute ever becomes async we need a different approach
std::unique_ptr<logging::Logger> owned_run_logger;
@ -1939,6 +1971,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
session_state_->IncrementGraphExecutionCounter();
#endif
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode, run_options.terminate, run_logger,
run_options.only_execute_path_to_fetches));
@ -1962,7 +1995,6 @@ Status InferenceSession::Run(const RunOptions& run_options,
ShrinkMemoryArenas(arenas_to_shrink);
}
}
--current_num_runs_;
// keep track of telemetry
++telemetry_.total_runs_since_last_;

View file

@ -664,6 +664,12 @@ class InferenceSession {
std::basic_string<ORTCHAR_T> thread_pool_name_;
std::basic_string<ORTCHAR_T> inter_thread_pool_name_;
// When true, thread-pool (TP) threads are forced to stop spinning as soon as
// the last of the concurrent Run() calls returns, which decreases CPU usage
// between infrequent requests.
// Spinning is restarted on the next Run().
bool force_spinning_stop_between_runs_ = false;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> thread_pool_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;