RunAsync C/CXX API (#16613)

Implement RunAsync API - the session will run in a thread of intra-op
thread pool.

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-07-16 16:51:40 -07:00 committed by GitHub
parent 2cf31a20cf
commit e1ca8ee6d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 308 additions and 66 deletions

View file

@ -696,6 +696,15 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha
typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);
/** \brief Callback function for RunAsync
*
* \param[in] user_data User specific data that passed back to the callback
* \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr
* \param[out] num_outputs Number of outputs, on error, the value will be zero
* \param[out] status On error, status will provide details
*/
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
@ -4316,6 +4325,27 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
*
* \param[in] session
* \param[in] run_options If nullptr, will use a default ::OrtRunOptions
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] input Array of ::OrtValue%s of the input values
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
* \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
* The array will be passed back to run_async_callback
* \param[in] run_async_callback Callback function on model run completion
* \param[in] user_data User data that pass back to run_async_callback
*/
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
};
/*

View file

@ -1067,6 +1067,24 @@ struct SessionImpl : ConstSessionImpl<T> {
void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
*
* Wraps OrtApi::RunAsync
*
* \param[in] run_options
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] input_values Array of ::OrtValue%s of the input values
* \param[in] input_count Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
* The array will be passed back to the callback
* \param[in] output_count Number of elements in the output_names and outputs array
* \param[in] callback Callback function on model run completion
* \param[in] user_data User data that pass back to the callback
*/
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
/** \brief End profiling and return a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned

View file

@ -972,6 +972,16 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
}
template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
ort_input_values, input_count, output_names, output_count,
ort_output_values, callback, user_data));
}
template <typename T>
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
char* out = nullptr;

View file

@ -2300,6 +2300,116 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}
Status InferenceSession::Run(const RunOptions& run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches) {
size_t num_feeds = feed_names.size();
size_t num_fetches = fetch_names.size();
InlinedVector<std::string> feed_name_vec;
feed_name_vec.reserve(num_feeds);
InlinedVector<OrtValue> feed_vec;
feed_vec.reserve(num_feeds);
for (size_t i = 0; i != num_feeds; ++i) {
if (feed_names[i] == nullptr || feed_names[i][0] == '\0') {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
}
if (!feeds[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_names[i]).c_str());
}
feed_name_vec.emplace_back(feed_names[i]);
feed_vec.emplace_back(*feeds[i]);
}
// Create output feed
InlinedVector<std::string> fetch_name_vec;
fetch_name_vec.reserve(num_fetches);
for (size_t i = 0; i != num_fetches; ++i) {
if (fetch_names[i] == nullptr || fetch_names[i][0] == '\0') {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
}
fetch_name_vec.emplace_back(fetch_names[i]);
}
std::vector<OrtValue> fetch_vec;
fetch_vec.reserve(num_fetches);
for (size_t i = 0; i != num_fetches; ++i) {
if (fetches[i] != nullptr) {
fetch_vec.emplace_back(*fetches[i]);
} else {
fetch_vec.emplace_back();
}
}
Status status;
status = Run(run_options, feed_name_vec, feed_vec, fetch_name_vec, &fetch_vec, nullptr);
if (!status.IsOK())
return status;
// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> fetch_unique_ptrs;
fetch_unique_ptrs.reserve(num_fetches);
for (size_t i = 0; i != num_fetches; ++i) {
if (fetches[i] == nullptr) {
fetch_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetch_vec[i]));
} else {
fetch_unique_ptrs.emplace_back();
}
}
for (size_t i = 0; i != num_fetches; ++i) {
if (fetches[i] == nullptr) {
ORT_ENFORCE(fetch_unique_ptrs[i] != nullptr);
fetches[i] = fetch_unique_ptrs[i].release();
}
}
return Status::OK();
}
common::Status InferenceSession::RunAsync(const RunOptions* run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches,
RunAsyncCallbackFn callback,
void* user_data) {
size_t num_fetches = fetch_names.size();
if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
}
std::function<void()> run_fn = [=]() {
ORT_TRY {
Status status;
if (run_options) {
status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
} else {
RunOptions default_run_options;
status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
}
if (status.IsOK()) {
callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status));
} else {
callback(user_data, {}, 0, ToOrtStatus(status));
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([=]() {
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what())));
});
}
ORT_CATCH(...) {
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception")));
}
}; // run_fn
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
return Status::OK();
}
common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);

View file

@ -305,6 +305,20 @@ class InferenceSession {
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);
[[nodiscard]] common::Status Run(const RunOptions& run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches);
[[nodiscard]] common::Status RunAsync(const RunOptions* run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches,
RunAsyncCallbackFn callback,
void* user_data = nullptr);
/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.

View file

@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
feeds.reserve(input_len);
for (size_t i = 0; i != input_len; ++i) {
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
}
if (!input[i]) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
MakeString("NULL input supplied for input ", input_names[i]).c_str());
}
feed_names.emplace_back(input_names[i]);
feeds.emplace_back(*input[i]);
}
// Create output feed
InlinedVector<std::string> output_names;
output_names.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names.emplace_back(output_names1[i]);
}
std::vector<OrtValue> fetches;
fetches.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
fetches.emplace_back(*output[i]);
} else {
fetches.emplace_back();
}
}
gsl::span<const char* const> input_names_span(input_names, input_len);
gsl::span<const OrtValue* const> input_span(input, input_len);
gsl::span<const char* const> output_name_span(output_names, output_names_len);
gsl::span<OrtValue*> output_span(output, output_names_len);
Status status;
if (run_options == nullptr) {
OrtRunOptions op;
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
if (run_options) {
status = session->Run(*run_options,
input_names_span,
input_span,
output_name_span,
output_span);
} else {
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
const RunOptions default_run_options;
status = session->Run(default_run_options,
input_names_span,
input_span,
output_name_span,
output_span);
}
return ToOrtStatus(status);
API_IMPL_END
}
if (!status.IsOK())
return ToOrtStatus(status);
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
output_unique_ptrs.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
} else {
output_unique_ptrs.emplace_back();
}
}
gsl::span<const char* const> input_names_span(input_names, input_len);
gsl::span<const OrtValue* const> input_span(input, input_len);
gsl::span<const char* const> output_name_span(output_names, output_names_len);
gsl::span<OrtValue*> output_span(output, output_names_len);
assert(output_unique_ptrs.size() == output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
assert(output_unique_ptrs[i] != nullptr);
output[i] = output_unique_ptrs[i].release();
}
}
return nullptr;
return ToOrtStatus(session->RunAsync(run_options,
input_names_span,
input_span,
output_name_span,
output_span,
run_async_callback,
user_data));
API_IMPL_END
}
@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,
&OrtApis::RunAsync,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

View file

@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions
ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** outputs,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
} // namespace OrtApis

View file

@ -3267,8 +3267,8 @@ TEST(MultiKernelSingleSchemaTest, valid) {
Ort::Value::CreateTensor<float>(memory_info, x_value, 10, x_dim, 1),
};
Ort::RunOptions run_optoins;
auto output_tensors = session.Run(run_optoins, input_names, input_tensors, 1, output_names, 2);
Ort::RunOptions run_options;
auto output_tensors = session.Run(run_options, input_names, input_tensors, 1, output_names, 2);
ASSERT_TRUE(*output_tensors[1].GetTensorData<int32_t>() == 72);
}
@ -3346,3 +3346,80 @@ TEST(MultiKernelSingleSchemaTest, DuplicateKernel) {
}
#endif
static std::thread::id caller_tid = std::this_thread::get_id();
static std::atomic_bool atomic_wait{false};
void CallbackSucceed(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status_ptr) {
auto callee_tid = std::this_thread::get_id();
EXPECT_NE(*(reinterpret_cast<std::thread::id*>(user_data)), callee_tid);
Ort::Status status(status_ptr);
EXPECT_TRUE(status.IsOK());
EXPECT_EQ(num_outputs, 1UL);
Ort::Value output_value(outputs[0]);
EXPECT_EQ(output_value.At<float>({1, 0}), 9.f);
output_value.release();
atomic_wait.store(true);
}
TEST(CApiTest, RunAsync) {
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(2);
Ort::Session session(*ort_env, MODEL_URI, session_options);
const char* input_names[] = {"X"};
float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t x_dim[] = {3, 2};
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensors[1] = {
Ort::Value::CreateTensor<float>(memory_info, x_value, 6, x_dim, 2),
};
const char* output_names[] = {"Y"};
Ort::RunOptions run_options;
Ort::Value output_values[1] = {Ort::Value{nullptr}};
EXPECT_NO_THROW(session.RunAsync(run_options,
input_names,
input_tensors,
1,
output_names,
output_values,
1,
CallbackSucceed,
&caller_tid));
std::chrono::duration<double, std::milli> dur{100};
// timeout in about 10 secs
for (int i = 0; i < 100 && !atomic_wait.load(); ++i) {
std::this_thread::sleep_for(dur);
}
EXPECT_EQ(atomic_wait.load(), true);
}
void CallbackFail(void*, OrtValue**, size_t, OrtStatusPtr) {
EXPECT_TRUE(false); // the callback is not supposed to be invoked
}
TEST(CApiTest, RunAsyncFail) {
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1); // This will cause RunAsync fail
Ort::Session session(*ort_env, MODEL_URI, session_options);
const char* input_names[] = {"X"};
float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t x_dim[] = {3, 2};
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensors[1] = {
Ort::Value::CreateTensor<float>(memory_info, x_value, 6, x_dim, 2),
};
Ort::Value output_values[1] = {Ort::Value{nullptr}};
const char* output_names[] = {"Y"};
Ort::RunOptions run_options;
EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
}