onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc
2019-01-25 19:41:10 -08:00

638 lines
25 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/onnx_protobuf.h" //TODO: remove this
#include "core/session/onnxruntime_c_api.h"
#include "core/session/allocator_impl.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_provider.h"
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>
#include <sstream>
#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/clog_sink.h"
#include "core/common/status.h"
#include "core/graph/graph.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/ml_value.h"
#include "core/framework/environment.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/session/inference_session.h"
#include "abi_session_options_impl.h"
using namespace onnxruntime::logging;
using onnxruntime::BFloat16;
using onnxruntime::DataTypeImpl;
using onnxruntime::Environment;
using onnxruntime::IAllocator;
using onnxruntime::InputDefList;
using onnxruntime::MLFloat16;
using onnxruntime::MLStatus;
using onnxruntime::MLValue;
using onnxruntime::OutputDefList;
using onnxruntime::Tensor;
using onnxruntime::ToOrtStatus;
using onnxruntime::common::Status;
// Evaluates `expr` exactly once. If the result is non-null (an error status), that status is
// returned from the enclosing function. Used with expressions yielding OrtStatus*, where
// nullptr means success.
#define ORT_API_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status) return _status; \
} while (0)
// Concrete object behind the opaque OrtEnv handle created by OrtInitialize / OrtInitializeWithCustomLogger.
// Takes ownership of both pointers passed to the constructor and deletes them in the destructor.
// Copying is disabled so that ownership cannot be duplicated.
struct OrtEnv {
public:
// Owned; deleted in ~OrtEnv.
Environment* value;
// Owned; deleted in ~OrtEnv. Handed (non-owning) to each InferenceSession built by CreateSessionImpl.
LoggingManager* loggingManager;
OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) {
}
/**
* Deletes the logging manager first, then the Environment.
* NOTE(review): an earlier comment stated this calls ::google::protobuf::ShutdownProtobufLibrary.
* That call is not made here directly; presumably it happens in Environment's destructor. Verify.
*/
~OrtEnv() {
delete loggingManager;
delete value;
}
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv);
};
// API_IMPL_BEGIN / API_IMPL_END bracket the body of an OrtStatus*-returning C API entry point.
// They keep C++ exceptions from crossing the C boundary: any std::exception is converted into an
// ORT_RUNTIME_EXCEPTION status carrying ex.what(). Exceptions that do not derive from
// std::exception are not caught. The exception is caught by const reference, per the usual idiom.
#define API_IMPL_BEGIN try {
#define API_IMPL_END                                          \
  }                                                           \
  catch (const std::exception& ex) {                          \
    return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
  }
// TENSOR_READ_API_BEGIN opens an API_IMPL_BEGIN try-block and must be closed with API_IMPL_END.
// It expects a variable named `value` (an OrtValue pointer) in scope and binds:
//   v      - that value reinterpreted as a const onnxruntime::MLValue*
//   tensor - a const Tensor& obtained via MLValue::Get<Tensor>()
// NOTE(review): presumably Get<Tensor>() throws when the value is not a tensor. If so, the
// failure is reported through API_IMPL_END. Confirm.
#define TENSOR_READ_API_BEGIN \
API_IMPL_BEGIN \
auto v = reinterpret_cast<const ::onnxruntime::MLValue*>(value); \
auto& tensor = v->Get<onnxruntime::Tensor>();
// TENSOR_READWRITE_API_BEGIN is the mutable counterpart. It binds `v` as a non-const MLValue*
// and `tensor` as a Tensor* obtained via GetMutable<Tensor>(). It must also be closed with API_IMPL_END.
#define TENSOR_READWRITE_API_BEGIN \
API_IMPL_BEGIN \
auto v = reinterpret_cast<::onnxruntime::MLValue*>(value); \
auto tensor = v->GetMutable<onnxruntime::Tensor>();
class LoggingWrapper : public ISink {
public:
LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param)
: logging_function_{logging_function}, logger_param_{logger_param} {
}
void SendImpl(const Timestamp& /*timestamp*/ /*timestamp*/, const std::string& logger_id,
const Capture& message) override {
std::string s = message.Location().ToString();
logging_function_(logger_param_, static_cast<OrtLoggingLevel>(message.Severity()), message.Category(),
logger_id.c_str(), s.c_str(), message.Message().c_str());
}
private:
OrtLoggingFunction logging_function_;
void* logger_param_;
};
// Creates an OrtEnv whose default logger forwards every message to `logging_function`.
// On success *out receives a new OrtEnv that owns the Environment and the LoggingManager.
// On failure nothing is written to *out and an error status is returned.
ORT_API_STATUS_IMPL(OrtInitializeWithCustomLogger, OrtLoggingFunction logging_function,
                    _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid,
                    _Out_ OrtEnv** out) {
  API_IMPL_BEGIN
  // Reject arguments that would otherwise fail later in undefined ways:
  //  - LoggingWrapper::SendImpl calls logging_function unconditionally on every log message;
  //  - constructing std::string from a null pointer is undefined behavior;
  //  - *out is written on success.
  if (logging_function == nullptr)
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "logging_function cannot be null");
  if (logid == nullptr)
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "logid cannot be null");
  if (out == nullptr)
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "out cannot be null");
  std::string name = logid;
  std::unique_ptr<ISink> logger = std::make_unique<LoggingWrapper>(logging_function, logger_param);
  // NOTE(review): assumes OrtLoggingLevel values map 1:1 onto logging::Severity. Confirm.
  auto default_logging_manager = std::make_unique<LoggingManager>(std::move(logger),
                                                                  static_cast<Severity>(default_warning_level), false,
                                                                  LoggingManager::InstanceType::Default,
                                                                  &name);
  std::unique_ptr<Environment> env;
  Status status = Environment::Create(env);
  if (status.IsOK())
    *out = new OrtEnv(env.release(), default_logging_manager.release());
  return ToOrtStatus(status);
  API_IMPL_END
}
// Creates an OrtEnv whose default logger writes through a CLogSink.
// On success *out receives a new OrtEnv that owns the Environment and the LoggingManager.
// On failure nothing is written to *out and an error status is returned.
ORT_API_STATUS_IMPL(OrtInitialize, OrtLoggingLevel default_warning_level,
                    _In_ const char* logid, _Out_ OrtEnv** out) {
  API_IMPL_BEGIN
  // Constructing std::string from a null pointer is undefined behavior, and *out is written on
  // success, so validate both up front.
  if (logid == nullptr)
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "logid cannot be null");
  if (out == nullptr)
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "out cannot be null");
  std::string name = logid;
  // NOTE(review): assumes OrtLoggingLevel values map 1:1 onto logging::Severity. Confirm.
  auto default_logging_manager = std::make_unique<LoggingManager>(std::unique_ptr<ISink>{new CLogSink{}},
                                                                  static_cast<Severity>(default_warning_level), false,
                                                                  LoggingManager::InstanceType::Default,
                                                                  &name);
  std::unique_ptr<Environment> env;
  Status status = Environment::Create(env);
  if (status.IsOK())
    *out = new OrtEnv(env.release(), default_logging_manager.release());
  return ToOrtStatus(status);
  API_IMPL_END
}
// Reports in *out the total number of bytes across all strings of a string tensor.
// Terminating NULs are not counted. A negative element count yields ORT_INVALID_ARGUMENT.
ORT_API_STATUS_IMPL(OrtGetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* out) {
  TENSOR_READ_API_BEGIN
  // Fetch the typed data before inspecting the shape, exactly as before, so that a non-string
  // tensor is still reported by Data<std::string>() first.
  const std::string* strings = tensor.Data<std::string>();
  const int64_t element_count = tensor.Shape().Size();
  if (element_count < 0)
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "shape is invalid");

  size_t total_bytes = 0;
  for (int64_t idx = 0; idx < element_count; ++idx)
    total_bytes += strings[idx].size();
  *out = total_bytes;
  return nullptr;
  API_IMPL_END
}
// Copies the C strings s[0 .. N-1] into a string tensor with N elements.
// Returns ORT_INVALID_ARGUMENT if s_len < N, or if s or any of the first N entries is null.
// Validation happens before any element is written, so an invalid argument leaves the tensor unchanged.
ORT_API_STATUS_IMPL(OrtFillStringTensor, _In_ OrtValue* value, _In_ const char* const* s, size_t s_len) {
  TENSOR_READWRITE_API_BEGIN
  auto* dst = tensor->MutableData<std::string>();
  auto len = static_cast<size_t>(tensor->Shape().Size());
  if (s_len < len) {
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "input array is too short");
  }
  // Assigning a null const char* to std::string is undefined behavior, so reject null pointers.
  if (len != 0 && s == nullptr) {
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "input array cannot be null");
  }
  for (size_t i = 0; i != len; ++i) {
    if (s[i] == nullptr) {
      return OrtCreateStatus(ORT_INVALID_ARGUMENT, "input string cannot be null");
    }
  }
  for (size_t i = 0; i != len; ++i) {
    dst[i] = s[i];  // allocates and copies
  }
  return nullptr;
  API_IMPL_END
}
// Converts a caller-supplied shape into int64 dimensions and computes its element count.
// Returns false when a dimension does not fit in int64_t, or when the element count cannot be
// represented as both size_t and int64_t (TensorShape::Size() is int64_t). A shape that contains
// a zero dimension has zero elements and is never treated as an overflow.
static bool ConvertShapeForTensor(const size_t* shape, size_t shape_len, std::vector<int64_t>& dims,
                                  size_t& elem_count) {
  constexpr uint64_t kMaxDim = static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
  constexpr uint64_t kMaxSizeT = static_cast<uint64_t>(std::numeric_limits<size_t>::max());
  constexpr uint64_t kMaxCount = kMaxDim < kMaxSizeT ? kMaxDim : kMaxSizeT;
  dims.assign(shape_len, 0);
  elem_count = 1;
  bool has_zero_dim = false;
  bool overflow = false;
  for (size_t i = 0; i != shape_len; ++i) {
    const size_t dim = shape[i];
    if (static_cast<uint64_t>(dim) > kMaxDim)
      return false;
    dims[i] = static_cast<int64_t>(dim);
    if (dim == 0) {
      has_zero_dim = true;
    } else if (!overflow && static_cast<uint64_t>(elem_count) > kMaxCount / dim) {
      overflow = true;
    }
    // Unsigned arithmetic: a wrapped product is only kept when a zero dimension forces it to 0.
    elem_count *= dim;
  }
  return has_zero_dim || !overflow;
}
/**
 * Allocates a buffer for sizeof(T) * prod(shape) bytes through `allocator` and wraps it in a Tensor.
 * The Tensor is given an AllocatorWrapper around `allocator`, presumably so it frees the buffer on
 * destruction. Returns an error status if the shape/byte size overflows or the allocation returns null.
 */
template <typename T>
OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, OrtAllocator* allocator,
                            std::unique_ptr<Tensor>* out) {
  size_t elem_count = 0;
  std::vector<int64_t> shapes;
  if (!ConvertShapeForTensor(shape, shape_len, shapes, elem_count)) {
    return OrtCreateStatus(ORT_FAIL, "size overflow");
  }
  size_t size_to_allocate;
  if (!IAllocator::CalcMemSizeForArray(sizeof(T), elem_count, &size_to_allocate)) {
    return OrtCreateStatus(ORT_FAIL, "size overflow");
  }
  void* p_data = allocator->Alloc(allocator, size_to_allocate);
  if (p_data == nullptr)
    return OrtCreateStatus(ORT_FAIL, "not enough memory");
  *out = std::make_unique<Tensor>(DataTypeImpl::GetType<T>(),
                                  onnxruntime::TensorShape(shapes),
                                  static_cast<void*>(p_data),
                                  *allocator->Info(allocator),
                                  std::make_shared<onnxruntime::AllocatorWrapper>(allocator));
  return nullptr;
}
/**
 * Wraps the caller-owned buffer `p_data` (p_data_len bytes) in a Tensor without copying it.
 * No deleter is passed (nullptr), so presumably the Tensor never frees p_data and the caller keeps
 * ownership. `*info` is passed to the Tensor constructor, which presumably copies it; verify.
 * Returns ORT_INVALID_ARGUMENT if the shape/byte size overflows or the buffer is too small.
 */
template <typename T>
OrtStatus* CreateTensorImpl(const size_t* shape, size_t shape_len, const OrtAllocatorInfo* info,
                            void* p_data, size_t p_data_len, std::unique_ptr<Tensor>* out) {
  size_t elem_count = 0;
  std::vector<int64_t> shapes;
  // Without this check a wrapped element count could pass the p_data_len test below while the
  // tensor claims a far larger shape.
  if (!ConvertShapeForTensor(shape, shape_len, shapes, elem_count)) {
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "size overflow");
  }
  size_t size_to_allocate;
  if (!IAllocator::CalcMemSizeForArray(sizeof(T), elem_count, &size_to_allocate)) {
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, "size overflow");
  }
  if (size_to_allocate > p_data_len) {
    std::ostringstream oss;
    oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len;
    return OrtCreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str());
  }
  *out = std::make_unique<Tensor>(DataTypeImpl::GetType<T>(),
                                  onnxruntime::TensorShape(shapes),
                                  p_data,
                                  *info,
                                  nullptr);
  return nullptr;
}
/**
 * Creates an OrtValue holding a tensor that views the caller's buffer `p_data` (no copy of the data).
 * The element type is selected by `type`. COMPLEX64, COMPLEX128 and unknown values return
 * ORT_NOT_IMPLEMENTED. The allocator info is handed to the Tensor, which presumably copies it.
 */
ORT_API_STATUS_IMPL(OrtCreateTensorWithDataAsOrtValue, _In_ const OrtAllocatorInfo* info,
                    _In_ void* p_data, size_t p_data_len, _In_ const size_t* shape, size_t shape_len,
                    ONNXTensorElementDataType type, _Out_ OrtValue** out) {
  API_IMPL_BEGIN
  std::unique_ptr<Tensor> tensor;
  OrtStatus* create_status = nullptr;
  // Map the C element-type enum to the matching C++ element type.
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      create_status = CreateTensorImpl<float>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      create_status = CreateTensorImpl<uint8_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      create_status = CreateTensorImpl<int8_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      create_status = CreateTensorImpl<uint16_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      create_status = CreateTensorImpl<int16_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      create_status = CreateTensorImpl<int32_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      create_status = CreateTensorImpl<int64_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      create_status = CreateTensorImpl<std::string>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      create_status = CreateTensorImpl<bool>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      create_status = CreateTensorImpl<MLFloat16>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      create_status = CreateTensorImpl<BFloat16>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      create_status = CreateTensorImpl<double>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      create_status = CreateTensorImpl<uint32_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      create_status = CreateTensorImpl<uint64_t>(shape, shape_len, info, p_data, p_data_len, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
    default: {
      std::ostringstream message;
      message << "type " << type << " is not supported in this function";
      const std::string text = message.str();
      return OrtCreateStatus(ORT_NOT_IMPLEMENTED, text.c_str());
    }
  }
  if (create_status != nullptr)
    return create_status;

  // Hand the tensor to a heap MLValue, which takes ownership through the Tensor delete function.
  auto ml_value = std::make_unique<MLValue>();
  MLDataType tensor_type = DataTypeImpl::GetType<Tensor>();
  ml_value->Init(tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
  *out = reinterpret_cast<OrtValue*>(ml_value.release());
  return nullptr;
  API_IMPL_END
}
// Creates an OrtValue holding a freshly allocated tensor. The buffer is obtained from `allocator`.
// The element type is selected by `type`. COMPLEX64, COMPLEX128 and unknown values return
// ORT_NOT_IMPLEMENTED.
ORT_API_STATUS_IMPL(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
                    _In_ const size_t* shape, size_t shape_len, ONNXTensorElementDataType type,
                    _Out_ OrtValue** out) {
  API_IMPL_BEGIN
  std::unique_ptr<Tensor> tensor;
  OrtStatus* create_status = nullptr;
  // Map the C element-type enum to the matching C++ element type.
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      create_status = CreateTensorImpl<float>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      create_status = CreateTensorImpl<uint8_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      create_status = CreateTensorImpl<int8_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      create_status = CreateTensorImpl<uint16_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      create_status = CreateTensorImpl<int16_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      create_status = CreateTensorImpl<int32_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      create_status = CreateTensorImpl<int64_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      create_status = CreateTensorImpl<std::string>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      create_status = CreateTensorImpl<bool>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      create_status = CreateTensorImpl<MLFloat16>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      create_status = CreateTensorImpl<BFloat16>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      create_status = CreateTensorImpl<double>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      create_status = CreateTensorImpl<uint32_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      create_status = CreateTensorImpl<uint64_t>(shape, shape_len, allocator, &tensor);
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
    default: {
      std::ostringstream message;
      message << "type " << type << " is not supported in this function";
      const std::string text = message.str();
      return OrtCreateStatus(ORT_NOT_IMPLEMENTED, text.c_str());
    }
  }
  if (create_status != nullptr)
    return create_status;

  // Hand the tensor to a heap MLValue, which takes ownership through the Tensor delete function.
  auto ml_value = std::make_unique<MLValue>();
  MLDataType tensor_type = DataTypeImpl::GetType<Tensor>();
  ml_value->Init(tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
  *out = reinterpret_cast<OrtValue*>(ml_value.release());
  return nullptr;
  API_IMPL_END
}
// Builds, loads and initializes an InferenceSession for `model_path`.
// With null `options`, default SessionOptions are used and no custom ops or providers are registered.
// Otherwise custom-op libraries are loaded first, then every provider the factories produce is
// registered. The first failing step's status is returned. On success the session is released
// into *out.
template <typename T>
static OrtStatus* CreateSessionImpl(_In_ OrtEnv* env, _In_ T model_path,
                                    _In_ const OrtSessionOptions* options,
                                    _Out_ OrtSession** out) {
  API_IMPL_BEGIN
  auto sess = std::make_unique<::onnxruntime::InferenceSession>(
      options == nullptr ? onnxruntime::SessionOptions() : options->value, env->loggingManager);
  Status status;
  if (options != nullptr) {
    if (!options->custom_op_paths.empty()) {
      status = sess->LoadCustomOps(options->custom_op_paths);
      if (!status.IsOK())
        return ToOrtStatus(status);
    }
    for (auto& factory : options->provider_factories) {
      auto provider = factory->CreateProvider();
      if (!provider)
        continue;  // a factory may decline to produce a provider; skip it as before
      // Registration reports failures through its returned Status; surface them instead of
      // discarding the result and continuing with a partially configured session.
      status = sess->RegisterExecutionProvider(std::move(provider));
      if (!status.IsOK())
        return ToOrtStatus(status);
    }
  }
  status = sess->Load(model_path);
  if (!status.IsOK())
    return ToOrtStatus(status);
  status = sess->Initialize();
  if (!status.IsOK())
    return ToOrtStatus(status);
  *out = reinterpret_cast<OrtSession*>(sess.release());
  return nullptr;
  API_IMPL_END
}
// OrtCreateSession: creates a session from a model file path by delegating to CreateSessionImpl.
// On Windows the path is a wide-character string. Elsewhere it is a narrow (char) string.
#ifdef _WIN32
ORT_API_STATUS_IMPL(OrtCreateSession, _In_ OrtEnv* env, _In_ const wchar_t* model_path,
_In_ const OrtSessionOptions* options, _Out_ OrtSession** out) {
API_IMPL_BEGIN
return CreateSessionImpl(env, model_path, options, out);
API_IMPL_END
}
#else
ORT_API_STATUS_IMPL(OrtCreateSession, _In_ OrtEnv* env, _In_ const char* model_path,
_In_ const OrtSessionOptions* options, _Out_ OrtSession** out) {
API_IMPL_BEGIN
return CreateSessionImpl(env, model_path, options, out);
API_IMPL_END
}
#endif
// Runs the session on the named inputs and produces the named outputs.
// - input_names[i] / input[i] for i < input_len are the feeds. A repeated input name returns ORT_INVALID_ARGUMENT.
// - output_names1[i] for i < output_names_len name the fetches. A null or empty name returns ORT_INVALID_ARGUMENT.
// - output[i] may be pre-populated by the caller or null. For each null slot, a new heap MLValue is
//   allocated after a successful run. The caller then owns it (release with OrtReleaseValue).
// - A null run_options runs with a default-constructed OrtRunOptions.
// If Run fails, its status is returned and no new outputs are allocated.
// NOTE(review): input_names[i] is not null-checked; std::string(nullptr) is undefined behavior.
ORT_API_STATUS_IMPL(OrtRun, _In_ OrtSession* sess,
_In_ OrtRunOptions* run_options,
_In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len,
_In_ const char* const* output_names1, size_t output_names_len, _Out_ OrtValue** output) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
::onnxruntime::NameMLValMap in;
// All fence synchronization below is done on queue 0 of the CPU execution provider.
const int queue_id = 0;
for (size_t i = 0; i != input_len; ++i) {
// The MLValue is copied into the map. NOTE(review): presumably the copy shares the underlying tensor and fence.
auto kvp = in.insert(std::make_pair(std::string(input_names[i]),
*reinterpret_cast<const ::onnxruntime::MLValue*>(input[i])));
if (!kvp.second) {
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "duplicated input name");
}
::onnxruntime::MLValue& value = kvp.first->second;
if (value.Fence())
value.Fence()->BeforeUsingAsInput(onnxruntime::kCpuExecutionProvider, queue_id);
}
// Create output feed
std::vector<std::string> output_names(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names[i] = output_names1[i];
}
// Caller-provided outputs are copied into `fetches` so the session can use them as output targets.
// Slots left default-constructed are filled in by Run.
std::vector<MLValue> fetches(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
::onnxruntime::MLValue& value = *reinterpret_cast<::onnxruntime::MLValue*>(output[i]);
if (value.Fence())
value.Fence()->BeforeUsingAsOutput(onnxruntime::kCpuExecutionProvider, queue_id);
fetches[i] = value;
}
}
Status status;
if (run_options == nullptr) {
OrtRunOptions op;
status = session->Run(op, in, output_names, &fetches);
} else {
status = session->Run(*run_options, in, output_names, &fetches);
}
if (!status.IsOK())
return ToOrtStatus(status);
for (size_t i = 0; i != output_names_len; ++i) {
::onnxruntime::MLValue& value = fetches[i];
// Presumably marks the result as about to be consumed (read) by the caller.
if (value.Fence())
value.Fence()->BeforeUsingAsInput(onnxruntime::kCpuExecutionProvider, queue_id);
// Only null slots get a new MLValue. Pre-populated slots are not reassigned here.
// NOTE(review): that relies on the session writing into the caller's existing buffer. Confirm.
if (output[i] == nullptr) {
output[i] = reinterpret_cast<OrtValue*>(new MLValue(value));
}
}
return nullptr;
API_IMPL_END
}
// Stores in *output a writable pointer to the raw data buffer of the tensor held by `value`.
ORT_API_STATUS_IMPL(OrtGetTensorMutableData, _In_ OrtValue* value, _Out_ void** output) {
  API_IMPL_BEGIN
  auto* ml_value = reinterpret_cast<::onnxruntime::MLValue*>(value);
  auto* mutable_tensor = ml_value->GetMutable<onnxruntime::Tensor>();
  // TODO: string tensors are not rejected here; the raw buffer of a string tensor holds
  // std::string objects, not character data.
  *output = mutable_tensor->MutableDataRaw();
  return nullptr;
  API_IMPL_END
}
// Copies the contents of a string tensor with N elements into caller buffers:
// - `s` (s_len bytes) receives all strings concatenated, with no separators and no NUL terminators;
// - offsets[i] for i < N receives the byte offset of element i within `s`.
// Returns ORT_FAIL if offsets_len < N or if s_len is smaller than the total string length.
// Entries of `offsets` at index >= N are left untouched.
ORT_API_STATUS_IMPL(OrtGetStringTensorContent, _In_ const OrtValue* value,
                    _Out_ void* s, size_t s_len, _Out_ size_t* offsets, size_t offsets_len) {
  TENSOR_READ_API_BEGIN
  const auto* input = tensor.Data<std::string>();
  auto len = static_cast<size_t>(tensor.Shape().Size());
  if (offsets_len < len) {
    return OrtCreateStatus(ORT_FAIL, "space is not enough");
  }
  size_t total_len = 0;
  for (size_t i = 0; i != len; ++i) {
    total_len += input[i].size();
  }
  if (s_len < total_len) {
    return OrtCreateStatus(ORT_FAIL, "space is not enough");
  }
  char* p = static_cast<char*>(s);
  size_t f = 0;
  // Iterate over the tensor's N elements, not offsets_len. offsets_len may legitimately exceed N,
  // and indexing input[] past N would read outside the tensor and overrun `s`.
  for (size_t i = 0; i != len; ++i) {
    const std::string& str = input[i];
    if (!str.empty()) {
      memcpy(p + f, str.data(), str.size());
    }
    offsets[i] = f;
    f += str.size();
  }
  return nullptr;
  API_IMPL_END
}
// Parses a serialized ONNX TensorProto (input, input_len bytes) and converts it into a new OrtValue.
// Tensor memory is obtained through `allocator`. Parse failures return ORT_FAIL. Conversion
// failures are returned as-is. On success *out receives the new value.
ORT_API_STATUS_IMPL(OrtTensorProtoToOrtValue, _Inout_ OrtAllocator* allocator,
                    const void* input, int input_len, _Out_ OrtValue** out) {
  API_IMPL_BEGIN
  // Adapt the C allocator to the IAllocator interface expected by the conversion utility.
  std::shared_ptr<onnxruntime::IAllocator> wrapped_allocator =
      std::make_shared<onnxruntime::AllocatorWrapper>(allocator);
  ::ONNX_NAMESPACE::TensorProto tensor_proto;
  const bool parsed = tensor_proto.ParseFromArray(input, input_len);
  if (!parsed)
    return OrtCreateStatus(ORT_FAIL, "parse input tensor proto failed");

  auto result = std::make_unique<MLValue>();
  Status conversion_status =
      onnxruntime::utils::TensorProtoToMLValue(tensor_proto, wrapped_allocator, nullptr, 0, *result);
  if (conversion_status.IsOK()) {
    *out = reinterpret_cast<OrtValue*>(result.release());
    return nullptr;
  }
  return ToOrtStatus(conversion_status);
  API_IMPL_END
}
// Defines `void OrtRelease<INPUT_TYPE>(Ort<INPUT_TYPE>* value)`, which deletes the opaque handle as
// a single REAL_TYPE object. Passing nullptr is a no-op, since delete of a null pointer does nothing.
#define DEFINE_RELEASE_ORT_OBJECT_FUNCTION(INPUT_TYPE, REAL_TYPE) \
ORT_API(void, OrtRelease##INPUT_TYPE, Ort##INPUT_TYPE* value) { \
delete reinterpret_cast<REAL_TYPE*>(value); \
}
// Same as above, but uses delete[]. Use it for handles that were allocated as a REAL_TYPE array.
#define DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(INPUT_TYPE, REAL_TYPE) \
ORT_API(void, OrtRelease##INPUT_TYPE, Ort##INPUT_TYPE* value) { \
delete[] reinterpret_cast<REAL_TYPE*>(value); \
}
// Stores in *out the number of model inputs reported by the session.
ORT_API_STATUS_IMPL(OrtSessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_t* out) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
  std::pair<Status, const InputDefList*> p = session->GetModelInputs();
  if (!p.first.IsOK())
    return ToOrtStatus(p.first);
  // Guard the list pointer the same way GetInputOutputNameImpl does before dereferencing it.
  if (p.second == nullptr)
    return OrtCreateStatus(ORT_FAIL, "internal error");
  *out = p.second->size();
  return nullptr;
  API_IMPL_END
}
// Stores in *out the number of model outputs reported by the session.
ORT_API_STATUS_IMPL(OrtSessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
  // Outputs are described by OutputDefList, not InputDefList.
  std::pair<Status, const OutputDefList*> p = session->GetModelOutputs();
  if (!p.first.IsOK())
    return ToOrtStatus(p.first);
  // Guard the list pointer the same way GetInputOutputNameImpl does before dereferencing it.
  if (p.second == nullptr)
    return OrtCreateStatus(ORT_FAIL, "internal error");
  *out = p.second->size();
  return nullptr;
  API_IMPL_END
}
// Builds an OrtTypeInfo for model input `index` from its TypeProto.
// Returns ORT_FAIL if index is out of range.
ORT_API_STATUS_IMPL(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ struct OrtTypeInfo** out) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
  std::pair<Status, const InputDefList*> p = session->GetModelInputs();
  if (!p.first.IsOK())
    return ToOrtStatus(p.first);
  // Guard the list pointer the same way GetInputOutputNameImpl does before dereferencing it.
  if (p.second == nullptr)
    return OrtCreateStatus(ORT_FAIL, "internal error");
  if (p.second->size() <= index)
    return OrtCreateStatus(ORT_FAIL, "out of index");
  const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto();
  return OrtTypeInfo::FromDataTypeImpl(type_proto, out);
  API_IMPL_END
}
// Builds an OrtTypeInfo for model output `index` from its TypeProto.
// Returns ORT_FAIL if index is out of range.
ORT_API_STATUS_IMPL(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ struct OrtTypeInfo** out) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
  // Outputs are described by OutputDefList, not InputDefList.
  std::pair<Status, const OutputDefList*> p = session->GetModelOutputs();
  if (!p.first.IsOK())
    return ToOrtStatus(p.first);
  // Guard the list pointer the same way GetInputOutputNameImpl does before dereferencing it.
  if (p.second == nullptr)
    return OrtCreateStatus(ORT_FAIL, "internal error");
  if (p.second->size() <= index)
    return OrtCreateStatus(ORT_FAIL, "out of index");
  const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto();
  return OrtTypeInfo::FromDataTypeImpl(type_proto, out);
  API_IMPL_END
}
// Returns a NUL-terminated copy of `str` in memory obtained from `allocator`. The caller frees it
// through the same allocator. Returns nullptr if the allocator returns nullptr, instead of writing
// through a null pointer.
static char* StrDup(const std::string& str, OrtAllocator* allocator) {
  const size_t bytes = str.size() + 1;  // include room for the terminating '\0'
  char* output_string = reinterpret_cast<char*>(allocator->Alloc(allocator, bytes));
  if (output_string == nullptr)
    return nullptr;
  memcpy(output_string, str.c_str(), bytes);  // c_str() supplies the trailing '\0'
  return output_string;
}
// Shared implementation of OrtSessionGetInputName / OrtSessionGetOutputName.
// Copies the name of model input (is_input) or output `index` into memory from `allocator` and
// stores it in *output. Returns ORT_FAIL on a missing definition list, an out-of-range index,
// or a failed allocation.
static OrtStatus* GetInputOutputNameImpl(_In_ const OrtSession* sess, size_t index,
                                         _Inout_ OrtAllocator* allocator, bool is_input,
                                         _Out_ char** output) {
  auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
  // The single ternary requires GetModelInputs and GetModelOutputs to yield the same pair type.
  std::pair<Status, const InputDefList*> p = is_input ? session->GetModelInputs() : session->GetModelOutputs();
  if (!p.first.IsOK())
    return ToOrtStatus(p.first);
  if (p.second == nullptr)
    return OrtCreateStatus(ORT_FAIL, "internal error");
  const InputDefList& defs = *p.second;
  if (index >= defs.size())
    return OrtCreateStatus(ORT_FAIL, "index out of range");
  char* name = StrDup(defs[index]->Name(), allocator);
  // Do not report success with a null name if the copy could not be allocated.
  if (name == nullptr)
    return OrtCreateStatus(ORT_FAIL, "not enough memory");
  *output = name;
  return nullptr;
}
ORT_API(int, OrtIsTensor, _In_ const OrtValue* value) {
auto v = reinterpret_cast<const ::onnxruntime::MLValue*>(value);
return v->IsTensor() ? 1 : 0;
}
ORT_API(void*, OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size) {
try {
return ptr->Alloc(ptr, size);
} catch (std::exception&) {
return nullptr;
}
}
// Releases `p` through the allocator's Free callback. A std::exception thrown by the callback is
// swallowed, because this entry point has no way to report it.
ORT_API(void, OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p) {
  try {
    ptr->Free(ptr, p);
  } catch (const std::exception&) {
    // best-effort: nothing to report
  }
}
ORT_API(const struct OrtAllocatorInfo*, OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr) {
try {
return ptr->Info(ptr);
} catch (std::exception&) {
return nullptr;
}
}
// Copies the name of model input `index` into memory from `allocator` and stores it in *output.
ORT_API_STATUS_IMPL(OrtSessionGetInputName, _In_ const OrtSession* sess, size_t index,
                    _Inout_ OrtAllocator* allocator, _Out_ char** output) {
  API_IMPL_BEGIN
  constexpr bool kQueryInputs = true;
  OrtStatus* const result = GetInputOutputNameImpl(sess, index, allocator, kQueryInputs, output);
  return result;
  API_IMPL_END
}
// Copies the name of model output `index` into memory from `allocator` and stores it in *output.
ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t index,
                    _Inout_ OrtAllocator* allocator, _Out_ char** output) {
  API_IMPL_BEGIN
  constexpr bool kQueryInputs = false;
  OrtStatus* const result = GetInputOutputNameImpl(sess, index, allocator, kQueryInputs, output);
  return result;
  API_IMPL_END
}
// Public release functions: each deletes the concrete object behind the opaque handle.
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
// OrtStatus is released with delete[] on a char array.
// NOTE(review): this presumes OrtCreateStatus allocates statuses as char arrays; verify.
DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char)