From 39dfcd5d84d428f7e3c7fd99a9a74540642b1f5d Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Tue, 15 Aug 2023 14:29:10 -0700 Subject: [PATCH] Allow RunAsync with global TP (#17157) Allow RunAsync called with a global thread pool. --------- Co-authored-by: Randy Shuai --- onnxruntime/core/session/inference_session.cc | 19 +++--- .../global_thread_pools/test_inference.cc | 63 +++++++++++++++++++ 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 7ff0e51192..6a70176ebc 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2386,34 +2386,31 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options, RunAsyncCallbackFn callback, void* user_data) { size_t num_fetches = fetch_names.size(); - if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) { + auto* tp = GetIntraOpThreadPoolToUse(); + if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync"); } std::function run_fn = [=]() { + Status status = Status::OK(); ORT_TRY { - Status status; if (run_options) { status = Run(*run_options, feed_names, feeds, fetch_names, fetches); } else { RunOptions default_run_options; status = Run(default_run_options, feed_names, feeds, fetch_names, fetches); } - if (status.IsOK()) { - callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status)); - } else { - callback(user_data, {}, 0, ToOrtStatus(status)); - } } ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([=]() { - callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()))); + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()); }); } ORT_CATCH(...) { - callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception"))); + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception"); } + callback(user_data, fetches.data(), status.IsOK() ? num_fetches : 0, ToOrtStatus(status)); }; // run_fn - concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn); + concurrency::ThreadPool::Schedule(tp, run_fn); return Status::OK(); } diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 0e3a6aee79..4772e7de2b 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "test_allocator.h" #include "../shared_lib/test_fixture.h" @@ -153,6 +154,68 @@ TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple) { } } +static std::thread::id caller_tid = std::this_thread::get_id(); +static std::atomic_bool atomic_wait{false}; + +void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status_ptr) { + const float* expected_result = reinterpret_cast(user_data); + auto callee_tid = std::this_thread::get_id(); + EXPECT_NE(caller_tid, callee_tid); + Ort::Status status(status_ptr); + EXPECT_TRUE(status.IsOK()); + EXPECT_EQ(num_outputs, 1UL); + Ort::Value output_value(outputs[0]); + EXPECT_NEAR(output_value.GetTensorData()[1], expected_result[1], 0.001); + output_value.release(); + atomic_wait.store(true); +} + +TEST_P(CApiTestGlobalThreadPoolsWithProvider, simpleAsync) { + Ort::Session session = GetSessionObj(*ort_env, MODEL_URI, GetParam()); + if (!session) { + return; + } + + std::vector inputs; + std::vector expected_dims_y; + std::vector expected_values_y; + std::string output_name; + GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name); + + auto allocator = std::make_unique(); + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back(Ort::Value::CreateTensor(allocator->Info(), + inputs[i].values.data(), + inputs[i].values.size(), + inputs[i].dims.data(), + inputs[i].dims.size())); + } + std::vector output_names = {output_name.c_str()}; + std::vector ort_outputs; + ort_outputs.emplace_back(Ort::Value{nullptr}); + + atomic_wait.store(false); + session.RunAsync(Ort::RunOptions{nullptr}, + input_names.data(), + ort_inputs.data(), + ort_inputs.size(), + output_names.data(), + ort_outputs.data(), + 1, + AsyncCallback, + expected_values_y.data()); + + std::chrono::duration dur{100}; + // timeout in about 10 secs + for (int i = 0; i < 100 && !atomic_wait.load(); ++i) { + std::this_thread::sleep_for(dur); + } + EXPECT_EQ(atomic_wait.load(), true); +} + // Test 2 // run inference on the same model using 2 sessions // destruct the 2 sessions only at the end