mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
RunAsync C/CXX API (#16613)
Implement RunAsync API - the session will run in a thread of intra-op thread pool. --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
parent
2cf31a20cf
commit
e1ca8ee6d4
8 changed files with 308 additions and 66 deletions
|
|
@ -696,6 +696,15 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha
|
|||
|
||||
typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);
|
||||
|
||||
/** \brief Callback function for RunAsync
|
||||
*
|
||||
* \param[in] user_data User specific data that passed back to the callback
|
||||
* \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr
|
||||
* \param[out] num_outputs Number of outputs, on error, the value will be zero
|
||||
* \param[out] status On error, status will provide details
|
||||
*/
|
||||
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
|
||||
|
||||
/** \brief The C API
|
||||
*
|
||||
* All C API functions are defined inside this structure as pointers to functions.
|
||||
|
|
@ -4316,6 +4325,27 @@ struct OrtApi {
|
|||
*/
|
||||
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
|
||||
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
|
||||
|
||||
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
|
||||
*
|
||||
* \param[in] session
|
||||
* \param[in] run_options If nullptr, will use a default ::OrtRunOptions
|
||||
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
|
||||
* \param[in] input Array of ::OrtValue%s of the input values
|
||||
* \param[in] input_len Number of elements in the input_names and inputs arrays
|
||||
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
|
||||
* \param[in] output_names_len Number of elements in the output_names and outputs array
|
||||
* \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
|
||||
* The array will be passed back to run_async_callback
|
||||
* \param[in] run_async_callback Callback function on model run completion
|
||||
* \param[in] user_data User data that pass back to run_async_callback
|
||||
*/
|
||||
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
|
||||
_In_reads_(input_len) const char* const* input_names,
|
||||
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
|
||||
_Inout_updates_all_(output_names_len) OrtValue** output,
|
||||
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -1067,6 +1067,24 @@ struct SessionImpl : ConstSessionImpl<T> {
|
|||
|
||||
void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
|
||||
|
||||
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
|
||||
*
|
||||
* Wraps OrtApi::RunAsync
|
||||
*
|
||||
* \param[in] run_options
|
||||
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
|
||||
* \param[in] input_values Array of ::OrtValue%s of the input values
|
||||
* \param[in] input_count Number of elements in the input_names and inputs arrays
|
||||
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
|
||||
* \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
|
||||
* The array will be passed back to the callback
|
||||
* \param[in] output_count Number of elements in the output_names and outputs array
|
||||
* \param[in] callback Callback function on model run completion
|
||||
* \param[in] user_data User data that pass back to the callback
|
||||
*/
|
||||
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
|
||||
|
||||
/** \brief End profiling and return a copy of the profiling file name.
|
||||
*
|
||||
* \param allocator to allocate memory for the copy of the string returned
|
||||
|
|
|
|||
|
|
@ -972,6 +972,16 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
|
|||
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
|
||||
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
||||
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
||||
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
|
||||
ort_input_values, input_count, output_names, output_count,
|
||||
ort_output_values, callback, user_data));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
|
||||
char* out = nullptr;
|
||||
|
|
|
|||
|
|
@ -2300,6 +2300,116 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
return retval;
|
||||
}
|
||||
|
||||
Status InferenceSession::Run(const RunOptions& run_options,
|
||||
gsl::span<const char* const> feed_names,
|
||||
gsl::span<const OrtValue* const> feeds,
|
||||
gsl::span<const char* const> fetch_names,
|
||||
gsl::span<OrtValue*> fetches) {
|
||||
size_t num_feeds = feed_names.size();
|
||||
size_t num_fetches = fetch_names.size();
|
||||
InlinedVector<std::string> feed_name_vec;
|
||||
feed_name_vec.reserve(num_feeds);
|
||||
InlinedVector<OrtValue> feed_vec;
|
||||
feed_vec.reserve(num_feeds);
|
||||
|
||||
for (size_t i = 0; i != num_feeds; ++i) {
|
||||
if (feed_names[i] == nullptr || feed_names[i][0] == '\0') {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
|
||||
}
|
||||
|
||||
if (!feeds[i]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_names[i]).c_str());
|
||||
}
|
||||
|
||||
feed_name_vec.emplace_back(feed_names[i]);
|
||||
feed_vec.emplace_back(*feeds[i]);
|
||||
}
|
||||
|
||||
// Create output feed
|
||||
InlinedVector<std::string> fetch_name_vec;
|
||||
fetch_name_vec.reserve(num_fetches);
|
||||
for (size_t i = 0; i != num_fetches; ++i) {
|
||||
if (fetch_names[i] == nullptr || fetch_names[i][0] == '\0') {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
|
||||
}
|
||||
fetch_name_vec.emplace_back(fetch_names[i]);
|
||||
}
|
||||
|
||||
std::vector<OrtValue> fetch_vec;
|
||||
fetch_vec.reserve(num_fetches);
|
||||
for (size_t i = 0; i != num_fetches; ++i) {
|
||||
if (fetches[i] != nullptr) {
|
||||
fetch_vec.emplace_back(*fetches[i]);
|
||||
} else {
|
||||
fetch_vec.emplace_back();
|
||||
}
|
||||
}
|
||||
|
||||
Status status;
|
||||
status = Run(run_options, feed_name_vec, feed_vec, fetch_name_vec, &fetch_vec, nullptr);
|
||||
|
||||
if (!status.IsOK())
|
||||
return status;
|
||||
|
||||
// We do it in two loops to make sure copy __ctors does not throw
|
||||
InlinedVector<std::unique_ptr<OrtValue>> fetch_unique_ptrs;
|
||||
fetch_unique_ptrs.reserve(num_fetches);
|
||||
for (size_t i = 0; i != num_fetches; ++i) {
|
||||
if (fetches[i] == nullptr) {
|
||||
fetch_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetch_vec[i]));
|
||||
} else {
|
||||
fetch_unique_ptrs.emplace_back();
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i != num_fetches; ++i) {
|
||||
if (fetches[i] == nullptr) {
|
||||
ORT_ENFORCE(fetch_unique_ptrs[i] != nullptr);
|
||||
fetches[i] = fetch_unique_ptrs[i].release();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::RunAsync(const RunOptions* run_options,
|
||||
gsl::span<const char* const> feed_names,
|
||||
gsl::span<const OrtValue* const> feeds,
|
||||
gsl::span<const char* const> fetch_names,
|
||||
gsl::span<OrtValue*> fetches,
|
||||
RunAsyncCallbackFn callback,
|
||||
void* user_data) {
|
||||
size_t num_fetches = fetch_names.size();
|
||||
if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
|
||||
}
|
||||
std::function<void()> run_fn = [=]() {
|
||||
ORT_TRY {
|
||||
Status status;
|
||||
if (run_options) {
|
||||
status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
|
||||
} else {
|
||||
RunOptions default_run_options;
|
||||
status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
|
||||
}
|
||||
if (status.IsOK()) {
|
||||
callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status));
|
||||
} else {
|
||||
callback(user_data, {}, 0, ToOrtStatus(status));
|
||||
}
|
||||
}
|
||||
ORT_CATCH(const std::exception& ex) {
|
||||
ORT_HANDLE_EXCEPTION([=]() {
|
||||
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what())));
|
||||
});
|
||||
}
|
||||
ORT_CATCH(...) {
|
||||
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception")));
|
||||
}
|
||||
}; // run_fn
|
||||
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
|
||||
std::vector<OrtValue>* p_fetches) {
|
||||
return Run(RunOptions(), feeds, output_names, p_fetches);
|
||||
|
|
|
|||
|
|
@ -305,6 +305,20 @@ class InferenceSession {
|
|||
std::vector<OrtValue>* p_fetches,
|
||||
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);
|
||||
|
||||
[[nodiscard]] common::Status Run(const RunOptions& run_options,
|
||||
gsl::span<const char* const> feed_names,
|
||||
gsl::span<const OrtValue* const> feeds,
|
||||
gsl::span<const char* const> fetch_names,
|
||||
gsl::span<OrtValue*> fetches);
|
||||
|
||||
[[nodiscard]] common::Status RunAsync(const RunOptions* run_options,
|
||||
gsl::span<const char* const> feed_names,
|
||||
gsl::span<const OrtValue* const> feeds,
|
||||
gsl::span<const char* const> fetch_names,
|
||||
gsl::span<OrtValue*> fetches,
|
||||
RunAsyncCallbackFn callback,
|
||||
void* user_data = nullptr);
|
||||
|
||||
/**
|
||||
* Run a pre-loaded and pre-intialized model.
|
||||
* Multiple threads are allowed to run this function; hence its thread-safe.
|
||||
|
|
|
|||
|
|
@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
|
|||
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
_In_reads_(input_len) const char* const* input_names,
|
||||
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
|
||||
_Inout_updates_all_(output_names_len) OrtValue** output) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
|
||||
|
||||
InlinedVector<std::string> feed_names;
|
||||
feed_names.reserve(input_len);
|
||||
InlinedVector<OrtValue> feeds;
|
||||
feeds.reserve(input_len);
|
||||
|
||||
for (size_t i = 0; i != input_len; ++i) {
|
||||
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
|
||||
}
|
||||
|
||||
if (!input[i]) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
|
||||
MakeString("NULL input supplied for input ", input_names[i]).c_str());
|
||||
}
|
||||
|
||||
feed_names.emplace_back(input_names[i]);
|
||||
feeds.emplace_back(*input[i]);
|
||||
}
|
||||
|
||||
// Create output feed
|
||||
InlinedVector<std::string> output_names;
|
||||
output_names.reserve(output_names_len);
|
||||
for (size_t i = 0; i != output_names_len; ++i) {
|
||||
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
|
||||
}
|
||||
output_names.emplace_back(output_names1[i]);
|
||||
}
|
||||
|
||||
std::vector<OrtValue> fetches;
|
||||
fetches.reserve(output_names_len);
|
||||
for (size_t i = 0; i != output_names_len; ++i) {
|
||||
if (output[i] != nullptr) {
|
||||
fetches.emplace_back(*output[i]);
|
||||
} else {
|
||||
fetches.emplace_back();
|
||||
}
|
||||
}
|
||||
gsl::span<const char* const> input_names_span(input_names, input_len);
|
||||
gsl::span<const OrtValue* const> input_span(input, input_len);
|
||||
gsl::span<const char* const> output_name_span(output_names, output_names_len);
|
||||
gsl::span<OrtValue*> output_span(output, output_names_len);
|
||||
|
||||
Status status;
|
||||
if (run_options == nullptr) {
|
||||
OrtRunOptions op;
|
||||
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
|
||||
if (run_options) {
|
||||
status = session->Run(*run_options,
|
||||
input_names_span,
|
||||
input_span,
|
||||
output_name_span,
|
||||
output_span);
|
||||
} else {
|
||||
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
|
||||
const RunOptions default_run_options;
|
||||
status = session->Run(default_run_options,
|
||||
input_names_span,
|
||||
input_span,
|
||||
output_name_span,
|
||||
output_span);
|
||||
}
|
||||
return ToOrtStatus(status);
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
if (!status.IsOK())
|
||||
return ToOrtStatus(status);
|
||||
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
_In_reads_(input_len) const char* const* input_names,
|
||||
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
|
||||
_Inout_updates_all_(output_names_len) OrtValue** output,
|
||||
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
|
||||
|
||||
// We do it in two loops to make sure copy __ctors does not throw
|
||||
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
|
||||
output_unique_ptrs.reserve(output_names_len);
|
||||
for (size_t i = 0; i != output_names_len; ++i) {
|
||||
if (output[i] == nullptr) {
|
||||
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
|
||||
} else {
|
||||
output_unique_ptrs.emplace_back();
|
||||
}
|
||||
}
|
||||
gsl::span<const char* const> input_names_span(input_names, input_len);
|
||||
gsl::span<const OrtValue* const> input_span(input, input_len);
|
||||
gsl::span<const char* const> output_name_span(output_names, output_names_len);
|
||||
gsl::span<OrtValue*> output_span(output, output_names_len);
|
||||
|
||||
assert(output_unique_ptrs.size() == output_names_len);
|
||||
|
||||
for (size_t i = 0; i != output_names_len; ++i) {
|
||||
if (output[i] == nullptr) {
|
||||
assert(output_unique_ptrs[i] != nullptr);
|
||||
output[i] = output_unique_ptrs[i].release();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
return ToOrtStatus(session->RunAsync(run_options,
|
||||
input_names_span,
|
||||
input_span,
|
||||
output_name_span,
|
||||
output_span,
|
||||
run_async_callback,
|
||||
user_data));
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
|
|
@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = {
|
|||
&OrtApis::GetROCMProviderOptionsAsString,
|
||||
&OrtApis::ReleaseROCMProviderOptions,
|
||||
&OrtApis::CreateAndRegisterAllocatorV2,
|
||||
&OrtApis::RunAsync,
|
||||
};
|
||||
|
||||
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
|
||||
|
|
|
|||
|
|
@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions
|
|||
|
||||
ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
|
||||
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
|
||||
|
||||
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
_In_reads_(input_len) const char* const* input_names,
|
||||
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
|
||||
_Inout_updates_all_(output_names_len) OrtValue** outputs,
|
||||
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -3267,8 +3267,8 @@ TEST(MultiKernelSingleSchemaTest, valid) {
|
|||
Ort::Value::CreateTensor<float>(memory_info, x_value, 10, x_dim, 1),
|
||||
};
|
||||
|
||||
Ort::RunOptions run_optoins;
|
||||
auto output_tensors = session.Run(run_optoins, input_names, input_tensors, 1, output_names, 2);
|
||||
Ort::RunOptions run_options;
|
||||
auto output_tensors = session.Run(run_options, input_names, input_tensors, 1, output_names, 2);
|
||||
ASSERT_TRUE(*output_tensors[1].GetTensorData<int32_t>() == 72);
|
||||
}
|
||||
|
||||
|
|
@ -3346,3 +3346,80 @@ TEST(MultiKernelSingleSchemaTest, DuplicateKernel) {
|
|||
}
|
||||
|
||||
#endif
|
||||
|
||||
static std::thread::id caller_tid = std::this_thread::get_id();
|
||||
static std::atomic_bool atomic_wait{false};
|
||||
|
||||
void CallbackSucceed(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status_ptr) {
|
||||
auto callee_tid = std::this_thread::get_id();
|
||||
EXPECT_NE(*(reinterpret_cast<std::thread::id*>(user_data)), callee_tid);
|
||||
Ort::Status status(status_ptr);
|
||||
EXPECT_TRUE(status.IsOK());
|
||||
EXPECT_EQ(num_outputs, 1UL);
|
||||
Ort::Value output_value(outputs[0]);
|
||||
EXPECT_EQ(output_value.At<float>({1, 0}), 9.f);
|
||||
output_value.release();
|
||||
atomic_wait.store(true);
|
||||
}
|
||||
|
||||
TEST(CApiTest, RunAsync) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetIntraOpNumThreads(2);
|
||||
Ort::Session session(*ort_env, MODEL_URI, session_options);
|
||||
|
||||
const char* input_names[] = {"X"};
|
||||
float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
int64_t x_dim[] = {3, 2};
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value input_tensors[1] = {
|
||||
Ort::Value::CreateTensor<float>(memory_info, x_value, 6, x_dim, 2),
|
||||
};
|
||||
|
||||
const char* output_names[] = {"Y"};
|
||||
Ort::RunOptions run_options;
|
||||
Ort::Value output_values[1] = {Ort::Value{nullptr}};
|
||||
|
||||
EXPECT_NO_THROW(session.RunAsync(run_options,
|
||||
input_names,
|
||||
input_tensors,
|
||||
1,
|
||||
output_names,
|
||||
output_values,
|
||||
1,
|
||||
CallbackSucceed,
|
||||
&caller_tid));
|
||||
|
||||
std::chrono::duration<double, std::milli> dur{100};
|
||||
// timeout in about 10 secs
|
||||
for (int i = 0; i < 100 && !atomic_wait.load(); ++i) {
|
||||
std::this_thread::sleep_for(dur);
|
||||
}
|
||||
|
||||
EXPECT_EQ(atomic_wait.load(), true);
|
||||
}
|
||||
|
||||
void CallbackFail(void*, OrtValue**, size_t, OrtStatusPtr) {
|
||||
EXPECT_TRUE(false); // the callback is not supposed to be invoked
|
||||
}
|
||||
|
||||
TEST(CApiTest, RunAsyncFail) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetIntraOpNumThreads(1); // This will cause RunAsync fail
|
||||
Ort::Session session(*ort_env, MODEL_URI, session_options);
|
||||
|
||||
const char* input_names[] = {"X"};
|
||||
float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
int64_t x_dim[] = {3, 2};
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value input_tensors[1] = {
|
||||
Ort::Value::CreateTensor<float>(memory_info, x_value, 6, x_dim, 2),
|
||||
};
|
||||
Ort::Value output_values[1] = {Ort::Value{nullptr}};
|
||||
const char* output_names[] = {"Y"};
|
||||
|
||||
Ort::RunOptions run_options;
|
||||
EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
|
||||
}
|
||||
Loading…
Reference in a new issue