mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Stop creating static thread pool to fix random hang in onnx_test_runner (#14023)
This commit is contained in:
parent
533fe37cbd
commit
cd305a90d6
4 changed files with 12 additions and 15 deletions
|
|
@ -591,7 +591,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
owned_tests.push_back(std::move(l));
|
||||
});
|
||||
|
||||
TestEnv test_env(env, sf, TestEnv::GetDefaultThreadPool(Env::Default()), std::move(tests), stat);
|
||||
auto tp = TestEnv::CreateThreadPool(Env::Default());
|
||||
TestEnv test_env(env, sf, tp.get(), std::move(tests), stat);
|
||||
Status st = test_env.Run(p_models, concurrent_session_runs, repeat_count);
|
||||
if (!st.IsOK()) {
|
||||
fprintf(stderr, "%s\n", st.ErrorMessage().c_str());
|
||||
|
|
|
|||
|
|
@ -12,18 +12,9 @@
|
|||
|
||||
using onnxruntime::Status;
|
||||
|
||||
static std::unique_ptr<onnxruntime::concurrency::ThreadPool> default_pool;
|
||||
static std::once_flag default_pool_init;
|
||||
|
||||
PThreadPool TestEnv::GetDefaultThreadPool(onnxruntime::Env& env) {
|
||||
std::call_once(default_pool_init, [&env] {
|
||||
using namespace onnxruntime::concurrency;
|
||||
std::unique_ptr<OrtThreadPool> TestEnv::CreateThreadPool(onnxruntime::Env& env) {
|
||||
int core_num = env.GetNumPhysicalCpuCores();
|
||||
|
||||
onnxruntime::ThreadOptions t_opts;
|
||||
default_pool = std::make_unique<ThreadPool>(&env, t_opts, ORT_TSTR("onnx_runner_tp"), core_num, false);
|
||||
});
|
||||
return default_pool.get();
|
||||
return std::make_unique<OrtThreadPool>(&env, onnxruntime::ThreadOptions{}, ORT_TSTR("onnx_runner_tp"), core_num, false);
|
||||
}
|
||||
|
||||
TestEnv::TestEnv(Ort::Env& env, Ort::SessionOptions& so, PThreadPool tp,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ class ThreadPool;
|
|||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
using PThreadPool = onnxruntime::concurrency::ThreadPool*;
|
||||
using OrtThreadPool = onnxruntime::concurrency::ThreadPool;
|
||||
using PThreadPool = OrtThreadPool*;
|
||||
|
||||
/// <summary>
|
||||
/// Facilitates running tests
|
||||
|
|
@ -37,7 +38,7 @@ class TestEnv {
|
|||
|
||||
~TestEnv();
|
||||
|
||||
static PThreadPool GetDefaultThreadPool(onnxruntime::Env& env);
|
||||
static std::unique_ptr<OrtThreadPool> CreateThreadPool(onnxruntime::Env& env);
|
||||
|
||||
/// <summary>
|
||||
/// Runs all tests cases either concurrently or sequentially
|
||||
|
|
|
|||
|
|
@ -652,6 +652,10 @@ TEST_P(ModelTest, Run) {
|
|||
std::unique_ptr<ITestCase> l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info),
|
||||
per_sample_tolerance, relative_per_sample_tolerance);
|
||||
|
||||
#ifndef USE_DNNL
|
||||
auto tp = TestEnv::CreateThreadPool(Env::Default());
|
||||
#endif
|
||||
|
||||
for (bool is_single_thread : use_single_thread) {
|
||||
for (ExecutionMode execution_mode : execution_modes) {
|
||||
OrtSessionOptions* ortso;
|
||||
|
|
@ -742,7 +746,7 @@ TEST_P(ModelTest, Run) {
|
|||
if (data_count > 1 && tests_run_parallel.find(l->GetTestCaseName()) != tests_run_parallel.end()) {
|
||||
LOGS_DEFAULT(ERROR) << "Parallel test for " << l->GetTestCaseName(); // TODO(leca): change level to INFO or even delete the log once verified parallel test working
|
||||
Ort::SessionOptions ort_session_options(ortso);
|
||||
std::shared_ptr<TestCaseResult> results = TestCaseRequestContext::Run(TestEnv::GetDefaultThreadPool(Env::Default()), *l, *ort_env, ort_session_options, data_count, 1 /*repeat_count*/);
|
||||
std::shared_ptr<TestCaseResult> results = TestCaseRequestContext::Run(tp.get(), *l, *ort_env, ort_session_options, data_count, 1 /*repeat_count*/);
|
||||
for (EXECUTE_RESULT res : results->GetExcutionResult()) {
|
||||
EXPECT_EQ(res, EXECUTE_RESULT::SUCCESS) << "is_single_thread:" << is_single_thread << ", execution_mode:" << execution_mode << ", provider_name:"
|
||||
<< provider_name << ", test name:" << results->GetName() << ", result: " << res;
|
||||
|
|
|
|||
Loading…
Reference in a new issue