Stop creating static thread pool to fix random hang in onnx_test_runner (#14023)

This commit is contained in:
RandySheriffH 2022-12-19 19:48:14 -08:00 committed by GitHub
parent 533fe37cbd
commit cd305a90d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 15 deletions

View file

@ -591,7 +591,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
owned_tests.push_back(std::move(l));
});
TestEnv test_env(env, sf, TestEnv::GetDefaultThreadPool(Env::Default()), std::move(tests), stat);
auto tp = TestEnv::CreateThreadPool(Env::Default());
TestEnv test_env(env, sf, tp.get(), std::move(tests), stat);
Status st = test_env.Run(p_models, concurrent_session_runs, repeat_count);
if (!st.IsOK()) {
fprintf(stderr, "%s\n", st.ErrorMessage().c_str());

View file

@ -12,18 +12,9 @@
using onnxruntime::Status;
static std::unique_ptr<onnxruntime::concurrency::ThreadPool> default_pool;
static std::once_flag default_pool_init;
PThreadPool TestEnv::GetDefaultThreadPool(onnxruntime::Env& env) {
std::call_once(default_pool_init, [&env] {
using namespace onnxruntime::concurrency;
std::unique_ptr<OrtThreadPool> TestEnv::CreateThreadPool(onnxruntime::Env& env) {
int core_num = env.GetNumPhysicalCpuCores();
onnxruntime::ThreadOptions t_opts;
default_pool = std::make_unique<ThreadPool>(&env, t_opts, ORT_TSTR("onnx_runner_tp"), core_num, false);
});
return default_pool.get();
return std::make_unique<OrtThreadPool>(&env, onnxruntime::ThreadOptions{}, ORT_TSTR("onnx_runner_tp"), core_num, false);
}
TestEnv::TestEnv(Ort::Env& env, Ort::SessionOptions& so, PThreadPool tp,

View file

@ -25,7 +25,8 @@ class ThreadPool;
}
} // namespace onnxruntime
using PThreadPool = onnxruntime::concurrency::ThreadPool*;
using OrtThreadPool = onnxruntime::concurrency::ThreadPool;
using PThreadPool = OrtThreadPool*;
/// <summary>
/// Facilitates running tests
@ -37,7 +38,7 @@ class TestEnv {
~TestEnv();
static PThreadPool GetDefaultThreadPool(onnxruntime::Env& env);
static std::unique_ptr<OrtThreadPool> CreateThreadPool(onnxruntime::Env& env);
/// <summary>
/// Runs all tests cases either concurrently or sequentially

View file

@ -652,6 +652,10 @@ TEST_P(ModelTest, Run) {
std::unique_ptr<ITestCase> l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info),
per_sample_tolerance, relative_per_sample_tolerance);
#ifndef USE_DNNL
auto tp = TestEnv::CreateThreadPool(Env::Default());
#endif
for (bool is_single_thread : use_single_thread) {
for (ExecutionMode execution_mode : execution_modes) {
OrtSessionOptions* ortso;
@ -742,7 +746,7 @@ TEST_P(ModelTest, Run) {
if (data_count > 1 && tests_run_parallel.find(l->GetTestCaseName()) != tests_run_parallel.end()) {
LOGS_DEFAULT(ERROR) << "Parallel test for " << l->GetTestCaseName(); // TODO(leca): change level to INFO or even delete the log once verified parallel test working
Ort::SessionOptions ort_session_options(ortso);
std::shared_ptr<TestCaseResult> results = TestCaseRequestContext::Run(TestEnv::GetDefaultThreadPool(Env::Default()), *l, *ort_env, ort_session_options, data_count, 1 /*repeat_count*/);
std::shared_ptr<TestCaseResult> results = TestCaseRequestContext::Run(tp.get(), *l, *ort_env, ort_session_options, data_count, 1 /*repeat_count*/);
for (EXECUTE_RESULT res : results->GetExcutionResult()) {
EXPECT_EQ(res, EXECUTE_RESULT::SUCCESS) << "is_single_thread:" << is_single_thread << ", execution_mode:" << execution_mode << ", provider_name:"
<< provider_name << ", test name:" << results->GetName() << ", result: " << res;