onnxruntime/samples/c_cxx/imagenet/sync_api_posix.cc
Changming Sun c2aa2056b5
Sample for imagenet and batch prediction (#1372)
* Sample for imagenet and batch prediction
(Will add a readme later)
2019-07-16 14:23:45 -07:00

109 lines
4.1 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "sync_api.h"
#include <mutex>
#include <unsupported/Eigen/CXX11/ThreadPool>
#include <core/common/common.h>
#include <core/common/logging/logging.h>
#include "simple_thread_pool.h"
#include "onnxruntime_event.h"
using onnxruntime::common::Status;
// this can be passed to one of the following functions:
// OnnxRuntimeSetEventWhenCallbackReturns
class OnnxRuntimeCallbackInstance {
private:
std::vector<ONNXRUNTIME_EVENT> events_to_signal_;
public:
void AddEvent(ONNXRUNTIME_EVENT event);
onnxruntime::common::Status SignalAllEvents();
};
Status WaitAndCloseEvent(ONNXRUNTIME_EVENT finish_event) {
if (finish_event == nullptr)
return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::INVALID_ARGUMENT, "");
pthread_mutex_lock(&finish_event->finish_event_mutex);
while (!finish_event->finished) {
pthread_cond_wait(&finish_event->finish_event_data, &finish_event->finish_event_mutex);
}
pthread_mutex_unlock(&finish_event->finish_event_mutex);
delete finish_event;
return Status::OK();
}
Status CreateAndSubmitThreadpoolWork(ONNXRUNTIME_CALLBACK_FUNCTION callback, void* data, PThreadPool pool) {
if (callback == nullptr)
return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::INVALID_ARGUMENT, "callback cannot be NULL");
if (pool == nullptr)
return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::INVALID_ARGUMENT, "pool cannot be NULL");
pool->Schedule([=]() {
OnnxRuntimeCallbackInstance instance;
callback(&instance, data, nullptr);
Status st = instance.SignalAllEvents();
if (!st.IsOK()) {
LOGF_DEFAULT(ERROR, "SignalAllEvents failed:%s. aborting...\n", st.ErrorMessage().c_str());
abort();
}
});
return Status::OK();
}
using DefaultThreadPoolType = onnxruntime::SimpleThreadPoolTempl<onnxruntime::Env>;
static std::unique_ptr<DefaultThreadPoolType> default_pool;
static std::once_flag default_pool_init;
PThreadPool GetDefaultThreadPool(const onnxruntime::Env& env) {
std::call_once(default_pool_init, [&env] {
int core_num = env.GetNumCpuCores();
default_pool.reset(new DefaultThreadPoolType(core_num, env));
});
return default_pool.get();
}
Status OnnxRuntimeSetEventWhenCallbackReturns(ONNXRUNTIME_CALLBACK_INSTANCE pci, ONNXRUNTIME_EVENT finish_event) {
if (finish_event == nullptr)
return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::INVALID_ARGUMENT, "");
if (pci == nullptr) {
if (pthread_mutex_lock(&finish_event->finish_event_mutex)) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "lock failed");
}
finish_event->finished = true;
if (pthread_mutex_unlock(&finish_event->finish_event_mutex))
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "unlock failed");
if (!pthread_cond_broadcast(&finish_event->finish_event_data))
return Status::OK();
else
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "pthread_cond_broadcast failed");
} else {
pci->AddEvent(finish_event);
return Status::OK();
}
}
void OnnxRuntimeCallbackInstance::AddEvent(ONNXRUNTIME_EVENT event) { events_to_signal_.push_back(event); }
Status OnnxRuntimeCallbackInstance::SignalAllEvents() {
for (ONNXRUNTIME_EVENT finish_event : events_to_signal_) {
if (pthread_mutex_lock(&finish_event->finish_event_mutex)) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "lock failed");
}
finish_event->finished = true;
if (pthread_mutex_unlock(&finish_event->finish_event_mutex))
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "unlock failed");
if (pthread_cond_broadcast(&finish_event->finish_event_data))
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "pthread_cond_broadcast failed");
}
return Status::OK();
}
Status CreateOnnxRuntimeEvent(ONNXRUNTIME_EVENT* out) {
if (out == nullptr) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::INVALID_ARGUMENT, "");
*out = new OnnxRuntimeEvent();
return Status::OK();
}
void ONNXRuntimeCloseEvent(ONNXRUNTIME_EVENT finish_event) { delete finish_event; }