mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Allow RunAsync with global TP (#17157)
Allow RunAsync to be called with a global thread pool. Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
parent
c647e3e8ab
commit
39dfcd5d84
2 changed files with 71 additions and 11 deletions
|
|
@ -2386,34 +2386,31 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options,
|
|||
RunAsyncCallbackFn callback,
|
||||
void* user_data) {
|
||||
size_t num_fetches = fetch_names.size();
|
||||
if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
|
||||
auto* tp = GetIntraOpThreadPoolToUse();
|
||||
if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
|
||||
}
|
||||
std::function<void()> run_fn = [=]() {
|
||||
Status status = Status::OK();
|
||||
ORT_TRY {
|
||||
Status status;
|
||||
if (run_options) {
|
||||
status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
|
||||
} else {
|
||||
RunOptions default_run_options;
|
||||
status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
|
||||
}
|
||||
if (status.IsOK()) {
|
||||
callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status));
|
||||
} else {
|
||||
callback(user_data, {}, 0, ToOrtStatus(status));
|
||||
}
|
||||
}
|
||||
ORT_CATCH(const std::exception& ex) {
|
||||
ORT_HANDLE_EXCEPTION([=]() {
|
||||
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what())));
|
||||
ORT_HANDLE_EXCEPTION([&]() {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
|
||||
});
|
||||
}
|
||||
ORT_CATCH(...) {
|
||||
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception")));
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception");
|
||||
}
|
||||
callback(user_data, fetches.data(), status.IsOK() ? num_fetches : 0, ToOrtStatus(status));
|
||||
}; // run_fn
|
||||
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
|
||||
concurrency::ThreadPool::Schedule(tp, run_fn);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_allocator.h"
|
||||
#include "../shared_lib/test_fixture.h"
|
||||
|
|
@ -153,6 +154,68 @@ TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple) {
|
|||
}
|
||||
}
|
||||
|
||||
static std::thread::id caller_tid = std::this_thread::get_id();
|
||||
static std::atomic_bool atomic_wait{false};
|
||||
|
||||
void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status_ptr) {
|
||||
const float* expected_result = reinterpret_cast<const float*>(user_data);
|
||||
auto callee_tid = std::this_thread::get_id();
|
||||
EXPECT_NE(caller_tid, callee_tid);
|
||||
Ort::Status status(status_ptr);
|
||||
EXPECT_TRUE(status.IsOK());
|
||||
EXPECT_EQ(num_outputs, 1UL);
|
||||
Ort::Value output_value(outputs[0]);
|
||||
EXPECT_NEAR(output_value.GetTensorData<float>()[1], expected_result[1], 0.001);
|
||||
output_value.release();
|
||||
atomic_wait.store(true);
|
||||
}
|
||||
|
||||
// Runs the model once through Session::RunAsync and waits (by polling
// atomic_wait) until AsyncCallback reports that it has finished.
TEST_P(CApiTestGlobalThreadPoolsWithProvider, simpleAsync) {
  Ort::Session session = GetSessionObj<PATH_TYPE, float>(*ort_env, MODEL_URI, GetParam());
  if (!session) {
    // GetSessionObj yielded an empty session for this parameter; nothing to run.
    return;
  }

  // Fetch the model inputs together with the expected output description.
  std::vector<Input> model_inputs;
  std::vector<int64_t> expected_dims_y;
  std::vector<float> expected_values_y;
  std::string output_name;
  GetInputsAndExpectedOutputs(model_inputs, expected_dims_y, expected_values_y, output_name);

  // Build one tensor per input, along with the parallel array of input names.
  auto allocator = std::make_unique<MockedOrtAllocator>();
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_names;
  for (auto& input : model_inputs) {
    input_names.emplace_back(input.name);
    ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(allocator->Info(),
                                                            input.values.data(),
                                                            input.values.size(),
                                                            input.dims.data(),
                                                            input.dims.size()));
  }

  // A single output slot, initially empty.
  std::vector<const char*> output_names = {output_name.c_str()};
  std::vector<Ort::Value> ort_outputs;
  ort_outputs.emplace_back(Ort::Value{nullptr});

  // Clear the completion flag before scheduling; AsyncCallback sets it again.
  atomic_wait.store(false);
  session.RunAsync(Ort::RunOptions{nullptr},
                   input_names.data(),
                   ort_inputs.data(),
                   ort_inputs.size(),
                   output_names.data(),
                   ort_outputs.data(),
                   output_names.size(),
                   AsyncCallback,
                   expected_values_y.data());

  // Poll every 100 ms, at most 100 times (about 10 seconds in total).
  const std::chrono::duration<double, std::milli> poll_interval{100};
  int polls_left = 100;
  while (polls_left-- > 0 && !atomic_wait.load()) {
    std::this_thread::sleep_for(poll_interval);
  }
  EXPECT_EQ(atomic_wait.load(), true);
}
|
||||
|
||||
// Test 2
|
||||
// run inference on the same model using 2 sessions
|
||||
// destruct the 2 sessions only at the end
|
||||
|
|
|
|||
Loading…
Reference in a new issue