diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index d244a85910..4a45383bb2 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -591,7 +591,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); owned_tests.push_back(std::move(l)); }); - TestEnv test_env(env, sf, TestEnv::GetDefaultThreadPool(Env::Default()), std::move(tests), stat); + auto tp = TestEnv::CreateThreadPool(Env::Default()); + TestEnv test_env(env, sf, tp.get(), std::move(tests), stat); Status st = test_env.Run(p_models, concurrent_session_runs, repeat_count); if (!st.IsOK()) { fprintf(stderr, "%s\n", st.ErrorMessage().c_str()); diff --git a/onnxruntime/test/onnx/testenv.cc b/onnxruntime/test/onnx/testenv.cc index 849cdd4427..3e7a48f411 100644 --- a/onnxruntime/test/onnx/testenv.cc +++ b/onnxruntime/test/onnx/testenv.cc @@ -12,18 +12,9 @@ using onnxruntime::Status; -static std::unique_ptr default_pool; -static std::once_flag default_pool_init; - -PThreadPool TestEnv::GetDefaultThreadPool(onnxruntime::Env& env) { - std::call_once(default_pool_init, [&env] { - using namespace onnxruntime::concurrency; +std::unique_ptr TestEnv::CreateThreadPool(onnxruntime::Env& env) { int core_num = env.GetNumPhysicalCpuCores(); - - onnxruntime::ThreadOptions t_opts; - default_pool = std::make_unique(&env, t_opts, ORT_TSTR("onnx_runner_tp"), core_num, false); - }); - return default_pool.get(); + return std::make_unique(&env, onnxruntime::ThreadOptions{}, ORT_TSTR("onnx_runner_tp"), core_num, false); } TestEnv::TestEnv(Ort::Env& env, Ort::SessionOptions& so, PThreadPool tp, diff --git a/onnxruntime/test/onnx/testenv.h b/onnxruntime/test/onnx/testenv.h index 445b99d65a..6a29f8818b 100644 --- a/onnxruntime/test/onnx/testenv.h +++ b/onnxruntime/test/onnx/testenv.h @@ -25,7 +25,8 @@ class ThreadPool; } } // namespace onnxruntime -using PThreadPool = onnxruntime::concurrency::ThreadPool*; +using OrtThreadPool = onnxruntime::concurrency::ThreadPool; +using PThreadPool = OrtThreadPool*; /// /// Facilitates running tests @@ -37,7 +38,7 @@ class TestEnv { ~TestEnv(); - static PThreadPool GetDefaultThreadPool(onnxruntime::Env& env); + static std::unique_ptr CreateThreadPool(onnxruntime::Env& env); /// /// Runs all tests cases either concurrently or sequentially diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 0454c7ded3..361d2bc5f8 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -652,6 +652,10 @@ TEST_P(ModelTest, Run) { std::unique_ptr l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info), per_sample_tolerance, relative_per_sample_tolerance); +#ifndef USE_DNNL + auto tp = TestEnv::CreateThreadPool(Env::Default()); +#endif + for (bool is_single_thread : use_single_thread) { for (ExecutionMode execution_mode : execution_modes) { OrtSessionOptions* ortso; @@ -742,7 +746,7 @@ TEST_P(ModelTest, Run) { if (data_count > 1 && tests_run_parallel.find(l->GetTestCaseName()) != tests_run_parallel.end()) { LOGS_DEFAULT(ERROR) << "Parallel test for " << l->GetTestCaseName(); // TODO(leca): change level to INFO or even delete the log once verified parallel test working Ort::SessionOptions ort_session_options(ortso); - std::shared_ptr results = TestCaseRequestContext::Run(TestEnv::GetDefaultThreadPool(Env::Default()), *l, *ort_env, ort_session_options, data_count, 1 /*repeat_count*/); + std::shared_ptr results = TestCaseRequestContext::Run(tp.get(), *l, *ort_env, ort_session_options, data_count, 1 /*repeat_count*/); for (EXECUTE_RESULT res : results->GetExcutionResult()) { EXPECT_EQ(res, EXECUTE_RESULT::SUCCESS) << "is_single_thread:" << is_single_thread << ", execution_mode:" << execution_mode << ", provider_name:" << provider_name << ", test name:" << results->GetName() << ", result: " << res;