Allow RunAsync with global TP (#17157)

Allow RunAsync called with a global thread pool.

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-08-15 14:29:10 -07:00 committed by GitHub
parent c647e3e8ab
commit 39dfcd5d84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 11 deletions

View file

@ -2386,34 +2386,31 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options,
RunAsyncCallbackFn callback,
void* user_data) {
size_t num_fetches = fetch_names.size();
if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
auto* tp = GetIntraOpThreadPoolToUse();
if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
}
std::function<void()> run_fn = [=]() {
Status status = Status::OK();
ORT_TRY {
Status status;
if (run_options) {
status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
} else {
RunOptions default_run_options;
status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
}
if (status.IsOK()) {
callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status));
} else {
callback(user_data, {}, 0, ToOrtStatus(status));
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([=]() {
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what())));
ORT_HANDLE_EXCEPTION([&]() {
status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
});
}
ORT_CATCH(...) {
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception")));
status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception");
}
callback(user_data, fetches.data(), status.IsOK() ? num_fetches : 0, ToOrtStatus(status));
}; // run_fn
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
concurrency::ThreadPool::Schedule(tp, run_fn);
return Status::OK();
}

View file

@ -10,6 +10,7 @@
#include <fstream>
#include <sstream>
#include <atomic>
#include <thread>
#include <gtest/gtest.h>
#include "test_allocator.h"
#include "../shared_lib/test_fixture.h"
@ -153,6 +154,68 @@ TEST_P(CApiTestGlobalThreadPoolsWithProvider, simple) {
}
}
static std::thread::id caller_tid = std::this_thread::get_id();
static std::atomic_bool atomic_wait{false};
void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status_ptr) {
const float* expected_result = reinterpret_cast<const float*>(user_data);
auto callee_tid = std::this_thread::get_id();
EXPECT_NE(caller_tid, callee_tid);
Ort::Status status(status_ptr);
EXPECT_TRUE(status.IsOK());
EXPECT_EQ(num_outputs, 1UL);
Ort::Value output_value(outputs[0]);
EXPECT_NEAR(output_value.GetTensorData<float>()[1], expected_result[1], 0.001);
output_value.release();
atomic_wait.store(true);
}
TEST_P(CApiTestGlobalThreadPoolsWithProvider, simpleAsync) {
Ort::Session session = GetSessionObj<PATH_TYPE, float>(*ort_env, MODEL_URI, GetParam());
if (!session) {
return;
}
std::vector<Input> inputs;
std::vector<int64_t> expected_dims_y;
std::vector<float> expected_values_y;
std::string output_name;
GetInputsAndExpectedOutputs(inputs, expected_dims_y, expected_values_y, output_name);
auto allocator = std::make_unique<MockedOrtAllocator>();
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_names;
for (size_t i = 0; i < inputs.size(); i++) {
input_names.emplace_back(inputs[i].name);
ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(allocator->Info(),
inputs[i].values.data(),
inputs[i].values.size(),
inputs[i].dims.data(),
inputs[i].dims.size()));
}
std::vector<const char*> output_names = {output_name.c_str()};
std::vector<Ort::Value> ort_outputs;
ort_outputs.emplace_back(Ort::Value{nullptr});
atomic_wait.store(false);
session.RunAsync(Ort::RunOptions{nullptr},
input_names.data(),
ort_inputs.data(),
ort_inputs.size(),
output_names.data(),
ort_outputs.data(),
1,
AsyncCallback,
expected_values_y.data());
std::chrono::duration<double, std::milli> dur{100};
// timeout in about 10 secs
for (int i = 0; i < 100 && !atomic_wait.load(); ++i) {
std::this_thread::sleep_for(dur);
}
EXPECT_EQ(atomic_wait.load(), true);
}
// Test 2
// run inference on the same model using 2 sessions
// destruct the 2 sessions only at the end