onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc
2023-02-24 21:59:34 -08:00

2702 lines
105 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/session/onnxruntime_c_api.h"
#include "core/session/allocator_adapters.h"
#include "core/session/inference_session_utils.h"
#include "core/session/IOBinding.h"
#include "core/framework/allocator.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_provider.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/framework/utils.h"
#include <cassert>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <sstream>
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/common/narrow.h"
#include "core/common/status.h"
#include "core/common/safeint.h"
#include "core/graph/constants.h"
#include "core/graph/graph.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/ort_value.h"
#include "core/providers/get_execution_providers.h"
#include "core/session/environment.h"
#include "core/framework/callback.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/session/inference_session.h"
#include "core/session/ort_apis.h"
#include "core/session/ort_env.h"
#include "core/framework/data_types.h"
#include "abi_session_options_impl.h"
#include "core/framework/TensorSeq.h"
#include "core/platform/ort_mutex.h"
#include "core/common/string_helper.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#include "core/providers/cuda/cuda_execution_provider_info.h"
namespace onnxruntime {
ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
}
#endif
#ifdef ENABLE_TRAINING_APIS
#include "orttraining/training_api/include/onnxruntime_training_c_api.h"
#include "orttraining/training_api/ort_training_apis.h"
#endif
#ifdef USE_CANN
#include "core/providers/cann/cann_provider_factory.h"
#include "core/providers/cann/cann_execution_provider_info.h"
namespace onnxruntime {
ProviderInfo_CANN* TryGetProviderInfo_CANN();
}
#endif
#ifdef USE_DML
#include "core/providers/dml/dml_provider_factory.h"
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION;
#endif
#ifdef ENABLE_EXTENSION_CUSTOM_OPS
#include "onnxruntime_extensions.h"
#endif
#if defined(_MSC_VER) && !defined(__clang__)
// The warning is: "Do not assign the result of an allocation or a function call with an owner<T> return value to a raw pointer, use owner<T> instead(i .11)."
// But this file is for C API. It can't use unique_ptr/shared_ptr in function signature.
#pragma warning(disable : 26400)
#endif
using namespace onnxruntime::logging;
using onnxruntime::BFloat16;
using onnxruntime::DataTypeImpl;
using onnxruntime::Environment;
using onnxruntime::IAllocator;
using onnxruntime::InputDefList;
using onnxruntime::MLFloat16;
using onnxruntime::narrow;
using onnxruntime::OutputDefList;
using onnxruntime::Tensor;
using onnxruntime::ToOrtStatus;
using onnxruntime::common::Status;
using namespace onnxruntime;
// Return type used by the status-returning helper functions in this file.
// Defined only if not already provided. On _WIN32 it is OrtStatusPtr decorated with the SAL
// annotations _Check_return_ and _Ret_maybenull_; elsewhere it is a plain OrtStatus*.
#ifndef ORT_STATUS_PTR
#ifdef _WIN32
#define ORT_STATUS_PTR _Check_return_ _Ret_maybenull_ OrtStatusPtr
#else
#define ORT_STATUS_PTR OrtStatus*
#endif
#endif
// Prologue for read-only tensor APIs. Expands API_IMPL_BEGIN and then, from a variable named
// `value` that must be in scope at the expansion site, declares:
//   v      - `value` viewed as a const ::OrtValue*
//   tensor - a reference to the onnxruntime::Tensor obtained through v->Get<>()
#define TENSOR_READ_API_BEGIN \
API_IMPL_BEGIN \
auto v = reinterpret_cast<const ::OrtValue*>(value); \
auto& tensor = v->Get<onnxruntime::Tensor>();
// Prologue for tensor APIs that need mutable access. Expands API_IMPL_BEGIN and then, from a
// variable named `value` that must be in scope at the expansion site, declares:
//   v      - a copy of `value`
//   tensor - the result of v->GetMutable<onnxruntime::Tensor>() (note: a pointer-like result,
//            unlike the reference declared by TENSOR_READ_API_BEGIN)
#define TENSOR_READWRITE_API_BEGIN \
API_IMPL_BEGIN \
auto v = (value); \
auto tensor = v->GetMutable<onnxruntime::Tensor>();
// Obtains an OrtEnv through OrtEnv::GetInstance using a caller-supplied logging callback.
//   logging_function / logger_param - callback and its opaque argument, forwarded to the logging setup
//   logging_level, logid            - forwarded unchanged to the logging setup
//   out                             - receives whatever OrtEnv::GetInstance returns
// The Status filled in by GetInstance is converted with ToOrtStatus and returned.
ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function,
                    _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid,
                    _Outptr_ OrtEnv** out) {
  API_IMPL_BEGIN
  Status creation_status;
  OrtEnv::LoggingManagerConstructionInfo logging_info{logging_function, logger_param, logging_level, logid};
  *out = OrtEnv::GetInstance(logging_info, creation_status);
  return ToOrtStatus(creation_status);
  API_IMPL_END
}
// Obtains an OrtEnv through OrtEnv::GetInstance without a custom logging callback
// (both the callback and its parameter are passed as nullptr).
//   logging_level, logid - forwarded unchanged to the logging setup
//   out                  - receives whatever OrtEnv::GetInstance returns
// The Status filled in by GetInstance is converted with ToOrtStatus and returned.
ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel logging_level,
                    _In_ const char* logid, _Outptr_ OrtEnv** out) {
  API_IMPL_BEGIN
  Status creation_status;
  OrtEnv::LoggingManagerConstructionInfo logging_info{nullptr, nullptr, logging_level, logid};
  *out = OrtEnv::GetInstance(logging_info, creation_status);
  return ToOrtStatus(creation_status);
  API_IMPL_END
}
// Same as CreateEnv, but additionally forwards `tp_options` to OrtEnv::GetInstance.
//   logging_level, logid - forwarded unchanged to the logging setup
//   tp_options           - threading options handed to OrtEnv::GetInstance
//   out                  - receives whatever OrtEnv::GetInstance returns
ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithGlobalThreadPools, OrtLoggingLevel logging_level,
                    _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options,
                    _Outptr_ OrtEnv** out) {
  API_IMPL_BEGIN
  Status creation_status;
  OrtEnv::LoggingManagerConstructionInfo logging_info{nullptr, nullptr, logging_level, logid};
  *out = OrtEnv::GetInstance(logging_info, creation_status, tp_options);
  return ToOrtStatus(creation_status);
  API_IMPL_END
}
// Combination of CreateEnvWithCustomLogger and CreateEnvWithGlobalThreadPools: the logging
// callback, its parameter and `tp_options` are all forwarded to OrtEnv::GetInstance.
//   out - receives whatever OrtEnv::GetInstance returns
// The Status filled in by GetInstance is converted with ToOrtStatus and returned.
ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function,
                    _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid,
                    _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out) {
  API_IMPL_BEGIN
  Status creation_status;
  OrtEnv::LoggingManagerConstructionInfo logging_info{logging_function, logger_param, logging_level, logid};
  *out = OrtEnv::GetInstance(logging_info, creation_status, tp_options);
  return ToOrtStatus(creation_status);
  API_IMPL_END
}
// Turns platform telemetry on.
// `ort_env` is intentionally unused: telemetry hangs off the platform Env singleton
// (Env::Default()), not off the OrtEnv instance. Always returns nullptr (success).
ORT_API_STATUS_IMPL(OrtApis::EnableTelemetryEvents, _In_ const OrtEnv* ort_env) {
  API_IMPL_BEGIN
  ORT_UNUSED_PARAMETER(ort_env);
  const Env& platform_env = Env::Default();
  auto&& telemetry = platform_env.GetTelemetryProvider();
  telemetry.EnableTelemetryEvents();
  return nullptr;
  API_IMPL_END
}
// Turns platform telemetry off.
// `ort_env` is intentionally unused: telemetry hangs off the platform Env singleton
// (Env::Default()), not off the OrtEnv instance. Always returns nullptr (success).
ORT_API_STATUS_IMPL(OrtApis::DisableTelemetryEvents, _In_ const OrtEnv* ort_env) {
  API_IMPL_BEGIN
  ORT_UNUSED_PARAMETER(ort_env);
  const Env& platform_env = Env::Default();
  auto&& telemetry = platform_env.GetTelemetryProvider();
  telemetry.DisableTelemetryEvents();
  return nullptr;
  API_IMPL_END
}
// Sets the default logger severity on the logging manager owned by `ort_env`.
// The OrtLoggingLevel value is converted to logging::Severity by way of its integer value.
// Always returns nullptr (success) unless an exception is translated by API_IMPL_END.
ORT_API_STATUS_IMPL(OrtApis::UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env,
                    OrtLoggingLevel log_severity_level) {
  API_IMPL_BEGIN
  LoggingManager* logging_manager = ort_env->GetLoggingManager();
  const auto severity = static_cast<logging::Severity>(static_cast<int>(log_severity_level));
  logging_manager->SetDefaultLoggerSeverity(severity);
  return nullptr;
  API_IMPL_END
}
// Initializes `value` as a tensor of element type `ml_type` and the given shape, with storage
// obtained through `allocator` (wrapped in an IAllocatorImplWrappingOrtAllocator).
//   shape / shape_len - dimension array and its length
// Always returns nullptr; failures surface as exceptions from the callee.
ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len,
                                _Inout_ OrtAllocator* allocator, OrtValue& value) {
  TensorShape dims(shape, shape_len);
  AllocatorPtr wrapped_allocator =
      std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(allocator);
  Tensor::InitOrtValue(ml_type, dims, std::move(wrapped_allocator), value);
  return nullptr;
}
// Builds a Tensor of `elem_type` with the given shape and move-assigns it into `out`.
// Storage comes from the allocator returned by OrtApis::GetAllocatorWithDefaultOptions;
// an error from that call is returned to the caller unchanged.
ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) {
  // TODO(pranav): what allocator should be used to create the tensor here?
  // The default allocator is used to keep the API simple.
  OrtAllocator* default_allocator;
  ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&default_allocator));

  AllocatorPtr wrapped_allocator =
      std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(default_allocator);
  TensorShape dims(shape, shape_len);
  out = Tensor(elem_type, dims, std::move(wrapped_allocator));
  return nullptr;
}
/**
 * Initializes `ort_value` as a tensor that uses the caller-provided buffer `p_data`.
 *
 * The memory info `*info` is handed to Tensor::InitOrtValue (the original note on this
 * function states that a copy of the allocator info is created).
 *
 * Returns ORT_INVALID_ARGUMENT when:
 *  - any dimension in `shape` is negative,
 *  - the byte size of the tensor cannot be computed without overflow,
 *  - `p_data_len` is smaller than the required byte size.
 * Returns nullptr on success.
 */
ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info,
                                void* p_data, size_t p_data_len, OrtValue& ort_value) {
  TensorShape dims(shape, shape_len);
  for (int64_t dim : dims.GetDims()) {
    if (dim < 0) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape");
    }
  }

  const auto element_count = narrow<size_t>(dims.Size());
  size_t required_bytes;
  const bool size_ok = IAllocator::CalcMemSizeForArray(ml_type->Size(), element_count, &required_bytes);
  if (!size_ok) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "size overflow");
  }

  if (p_data_len < required_bytes) {
    std::ostringstream message;
    message << "not enough space: expected " << required_bytes << ", got " << p_data_len;
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, message.str().c_str());
  }

  Tensor::InitOrtValue(ml_type, dims, p_data, *info, ort_value);
  return nullptr;
}
// Creates a new OrtValue holding a tensor that uses the caller-provided buffer `p_data`.
//   info                 - memory info describing where `p_data` lives
//   p_data / p_data_len  - buffer and its size in bytes
//   shape / shape_len    - tensor dimensions
//   type                 - ONNX element type, mapped to an ML element type
//   out                  - receives the new OrtValue on success; ownership passes to the caller
// Any error from CreateTensorImpl is returned unchanged and `*out` is not written.
ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info,
                    _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len,
                    ONNXTensorElementDataType type, _Outptr_ OrtValue** out) {
  API_IMPL_BEGIN
  auto element_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
  auto ort_value = std::make_unique<OrtValue>();
  ORT_API_RETURN_IF_ERROR(
      CreateTensorImpl(element_type, shape, shape_len, info, p_data, p_data_len, *ort_value));
  *out = ort_value.release();
  return nullptr;
  API_IMPL_END
}
// Creates a new OrtValue holding a tensor whose storage is obtained through `allocator`.
//   shape / shape_len - tensor dimensions
//   type              - ONNX element type, mapped to an ML element type
//   out               - receives the new OrtValue on success; ownership passes to the caller
// Any error from CreateTensorImpl is returned unchanged and `*out` is not written.
ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
                    _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type,
                    _Outptr_ OrtValue** out) {
  API_IMPL_BEGIN
  auto element_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
  auto ort_value = std::make_unique<OrtValue>();
  ORT_API_RETURN_IF_ERROR(CreateTensorImpl(element_type, shape, shape_len, allocator, *ort_value));
  *out = ort_value.release();
  return nullptr;
  API_IMPL_END
}
// Creates a new OrtValue holding a sparse tensor with the given dense shape, using
// `allocator` (wrapped in an IAllocatorImplWrappingOrtAllocator) for storage.
//   dense_shape / dense_shape_len - dense dimensions; a negative entry yields ORT_INVALID_ARGUMENT
//   type                          - ONNX element type of the sparse tensor
//   out                           - receives the new OrtValue on success; ownership passes to the caller
// In builds with DISABLE_SPARSE_TENSORS defined this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
                    _In_ const int64_t* dense_shape, size_t dense_shape_len, ONNXTensorElementDataType type,
                    _Outptr_ OrtValue** out) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  auto element_type = DataTypeImpl::SparseTensorTypeFromONNXEnum(type)->GetElementType();
  assert(element_type->AsPrimitiveDataType() != nullptr);

  TensorShape dense_dims(dense_shape, dense_shape_len);
  for (int64_t dim : dense_dims.GetDims()) {
    if (dim < 0) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape");
    }
  }

  auto wrapped_allocator = std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(allocator);
  auto ort_value = std::make_unique<OrtValue>();
  SparseTensor::InitOrtValue(element_type, dense_dims, std::move(wrapped_allocator), *ort_value);
  *out = ort_value.release();
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(allocator);
  ORT_UNUSED_PARAMETER(dense_shape);
  ORT_UNUSED_PARAMETER(dense_shape_len);
  ORT_UNUSED_PARAMETER(type);
  ORT_UNUSED_PARAMETER(out);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
namespace {
#if !defined(DISABLE_SPARSE_TENSORS)
// Picks an IDataTransfer able to copy sparse tensor data between the two devices.
//  - CPU -> CPU: a CPUDataTransfer.
//  - With USE_CUDA, when either side is a GPU and the CUDA provider info is available:
//    the transfer object created by that provider.
// Throws (ORT_THROW) when no suitable transfer is found.
std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) {
  const bool cpu_to_cpu = src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU;
  if (cpu_to_cpu) {
    return std::make_unique<CPUDataTransfer>();
  }
#ifdef USE_CUDA
  const bool gpu_involved = src_device.Type() == OrtDevice::GPU || dst_device.Type() == OrtDevice::GPU;
  if (gpu_involved) {
    auto* cuda_info = TryGetProviderInfo_CUDA();
    if (cuda_info != nullptr) {
      return cuda_info->CreateGPUDataTransfer();
    }
  }
#endif
  ORT_THROW("Not able to find appropriate IDataTransfer to copy sparse data");
}
// Common argument validation for the FillSparseTensor* entry points.
// Returns the SparseTensor stored in `*v` after checking that:
//  - for string tensors, both the source data (`data_mem_info`) and the sparse tensor itself
//    are located on a CPU device;
//  - `values_shape` contains no negative dimension.
// Violations are reported by throwing (ORT_THROW).
SparseTensor& ValidateFillInputArgs(OrtValue* v, const TensorShape& values_shape, const OrtMemoryInfo* data_mem_info) {
  auto& sparse = SparseTensor::GetSparseTensorFromOrtValue(*v);

  if (sparse.IsDataTypeString()) {
    const bool all_on_cpu = data_mem_info->device.Type() == OrtDevice::CPU &&
                            sparse.Location().device.Type() == OrtDevice::CPU;
    if (!all_on_cpu) {
      ORT_THROW("Strings can only reside in CPU memory");
    }
  }

  for (int64_t dim : values_shape.GetDims()) {
    if (dim < 0) {
      ORT_THROW("tried Filling sparse tensor with negative value in values shape");
    }
  }
  return sparse;
}
// Views an opaque `const void*` values buffer as an array of C-string pointers.
//   p       - the pointer exactly as passed in
//   strings - the same address typed as `const char**`
// This used to be a union that was written through `p` and read through `strings`; reading an
// inactive union member is undefined behaviour in C++, so the conversion is now done with
// explicit casts while keeping both member names available to existing callers.
struct PtrConvert {
  explicit PtrConvert(const void* p_p)
      : p(p_p),
        // const_cast drops the top-level constness of the pointee (the pointer array itself);
        // the strings it points to remain `const char`.
        strings(static_cast<const char**>(const_cast<void*>(p_p))) {}
  const void* p;
  const char** strings;
};
#endif // !defined(DISABLE_SPARSE_TENSORS)
} // namespace
// Fills the sparse tensor held by `ort_value` with COO-format values and indices.
//   data_mem_info                   - describes where `values` lives
//   values_shape / values_shape_len - shape of the values; its element count is the value count
//   values                          - raw values, or an array of C strings for string tensors
//   indices_data / indices_num      - COO indices
// Validation failures and errors from the SparseTensor calls are raised as exceptions and
// translated by API_IMPL_END. With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::FillSparseTensorCoo, _Inout_ OrtValue* ort_value,
                    _In_ const OrtMemoryInfo* data_mem_info, _In_ const int64_t* values_shape,
                    size_t values_shape_len, _In_ const void* values, _In_ const int64_t* indices_data,
                    size_t indices_num) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  TensorShape values_dims(values_shape, values_shape_len);
  auto& sparse = ValidateFillInputArgs(ort_value, values_dims, data_mem_info);
  auto value_count = narrow<size_t>(values_dims.Size());
  auto coo_indices = gsl::make_span(indices_data, indices_num);

  if (!sparse.IsDataTypeString()) {
    auto transfer = GetDataTransfer(data_mem_info->device, sparse.Location().device);
    ORT_THROW_IF_ERROR(sparse.MakeCooData(*transfer, *data_mem_info, value_count, values, coo_indices));
  } else {
    PtrConvert string_view(values);
    ORT_THROW_IF_ERROR(sparse.MakeCooStrings(value_count, string_view.strings, coo_indices));
  }
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(data_mem_info);
  ORT_UNUSED_PARAMETER(values_shape);
  ORT_UNUSED_PARAMETER(values_shape_len);
  ORT_UNUSED_PARAMETER(values);
  ORT_UNUSED_PARAMETER(indices_data);
  ORT_UNUSED_PARAMETER(indices_num);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Fills the sparse tensor held by `ort_value` with CSR-format values and indices.
//   data_mem_info                          - describes where `values` lives
//   values_shape / values_shape_len        - shape of the values; its element count is the value count
//   values                                 - raw values, or an array of C strings for string tensors
//   inner_indices_data / inner_indices_num - CSR inner indices
//   outer_indices_data / outer_indices_num - CSR outer indices
// Validation failures and errors from the SparseTensor calls are raised as exceptions and
// translated by API_IMPL_END. With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::FillSparseTensorCsr, _Inout_ OrtValue* ort_value,
                    _In_ const OrtMemoryInfo* data_mem_info, _In_ const int64_t* values_shape,
                    size_t values_shape_len, _In_ const void* values,
                    _In_ const int64_t* inner_indices_data, size_t inner_indices_num,
                    _In_ const int64_t* outer_indices_data, size_t outer_indices_num) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  TensorShape values_dims(values_shape, values_shape_len);
  auto& sparse = ValidateFillInputArgs(ort_value, values_dims, data_mem_info);
  auto value_count = narrow<size_t>(values_dims.Size());
  auto inner_indices = gsl::make_span(inner_indices_data, inner_indices_num);
  auto outer_indices = gsl::make_span(outer_indices_data, outer_indices_num);

  if (!sparse.IsDataTypeString()) {
    auto transfer = GetDataTransfer(data_mem_info->device, sparse.Location().device);
    ORT_THROW_IF_ERROR(sparse.MakeCsrData(*transfer, *data_mem_info, value_count,
                                          values, inner_indices, outer_indices));
  } else {
    PtrConvert string_view(values);
    ORT_THROW_IF_ERROR(sparse.MakeCsrStrings(value_count, string_view.strings, inner_indices, outer_indices));
  }
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(data_mem_info);
  ORT_UNUSED_PARAMETER(values_shape);
  ORT_UNUSED_PARAMETER(values_shape_len);
  ORT_UNUSED_PARAMETER(values);
  ORT_UNUSED_PARAMETER(inner_indices_data);
  ORT_UNUSED_PARAMETER(inner_indices_num);
  ORT_UNUSED_PARAMETER(outer_indices_data);
  ORT_UNUSED_PARAMETER(outer_indices_num);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Fills the sparse tensor held by `ort_value` with block-sparse values and indices.
//   data_mem_info                          - describes where `values` lives
//   values_shape / values_shape_len        - shape of the values
//   values                                 - raw values, or an array of C strings for string tensors
//   indices_shape_data / indices_shape_len - shape of the indices; negative entries are rejected
//   indices_data                           - block-sparse indices (int32)
// Validation failures and errors from the SparseTensor calls are raised as exceptions and
// translated by API_IMPL_END. With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value,
                    _In_ const OrtMemoryInfo* data_mem_info, _In_ const int64_t* values_shape,
                    size_t values_shape_len, _In_ const void* values,
                    _In_ const int64_t* indices_shape_data, size_t indices_shape_len,
                    _In_ const int32_t* indices_data) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  TensorShape values_dims(values_shape, values_shape_len);
  auto& sparse = ValidateFillInputArgs(ort_value, values_dims, data_mem_info);

  TensorShape indices_dims(indices_shape_data, indices_shape_len);
  for (int64_t dim : indices_dims.GetDims()) {
    if (dim < 0) {
      ORT_THROW("tried Filling sparse tensor with negative value in block sparse indices shape");
    }
  }

  if (!sparse.IsDataTypeString()) {
    auto transfer = GetDataTransfer(data_mem_info->device, sparse.Location().device);
    ORT_THROW_IF_ERROR(sparse.MakeBlockSparseData(*transfer, *data_mem_info, values_dims,
                                                  values, indices_dims, indices_data));
  } else {
    PtrConvert string_view(values);
    ORT_THROW_IF_ERROR(sparse.MakeBlockSparseStrings(values_dims, string_view.strings, indices_dims, indices_data));
  }
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(data_mem_info);
  ORT_UNUSED_PARAMETER(values_shape);
  ORT_UNUSED_PARAMETER(values_shape_len);
  ORT_UNUSED_PARAMETER(values);
  ORT_UNUSED_PARAMETER(indices_shape_data);
  ORT_UNUSED_PARAMETER(indices_shape_len);
  ORT_UNUSED_PARAMETER(indices_data);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Creates a new OrtValue holding a sparse tensor whose values live in the caller-provided
// buffer `p_data`.
//   info                            - memory info describing where `p_data` lives
//   dense_shape / dense_shape_len   - dense dimensions of the tensor
//   values_shape / values_shape_len - shape of the values; a negative entry yields ORT_INVALID_ARGUMENT
//   type                            - ONNX element type; string element types are rejected with
//                                     ORT_INVALID_ARGUMENT
//   out                             - receives the new OrtValue on success; ownership passes to the caller
// With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info,
                    _Inout_ void* p_data, _In_ const int64_t* dense_shape, size_t dense_shape_len,
                    _In_ const int64_t* values_shape, size_t values_shape_len,
                    ONNXTensorElementDataType type, _Outptr_ OrtValue** out) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  auto element_type = DataTypeImpl::SparseTensorTypeFromONNXEnum(type)->GetElementType();
  assert(element_type->AsPrimitiveDataType() != nullptr);

  if (utils::IsDataTypeString(element_type)) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Can not use strings in pre-allocated memory."
                                 " Use CreateSparseTensorAsOrtValue() to allocate memory inside and copy");
  }

  TensorShape dense_dims(dense_shape, dense_shape_len);
  TensorShape values_dims(values_shape, values_shape_len);
  for (int64_t dim : values_dims.GetDims()) {
    if (dim < 0) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape");
    }
  }

  auto ort_value = std::make_unique<OrtValue>();
  SparseTensor::InitOrtValue(element_type, dense_dims, values_dims, p_data, *info, *ort_value);
  *out = ort_value.release();
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(info);
  ORT_UNUSED_PARAMETER(p_data);
  ORT_UNUSED_PARAMETER(dense_shape);
  ORT_UNUSED_PARAMETER(dense_shape_len);
  ORT_UNUSED_PARAMETER(values_shape);
  ORT_UNUSED_PARAMETER(values_shape_len);
  ORT_UNUSED_PARAMETER(type);
  ORT_UNUSED_PARAMETER(out);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Hands caller-owned COO indices to the sparse tensor held by `ort_value`.
// An empty span is passed when `indices_num` is 0 or `indices_data` is null.
// Errors from SparseTensor::UseCooIndices are raised as exceptions and translated by
// API_IMPL_END. With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data,
                    size_t indices_num) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  auto& sparse = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);

  gsl::span<int64_t> coo_indices;
  if (indices_num != 0 && indices_data != nullptr) {
    coo_indices = gsl::make_span(indices_data, indices_num);
  }

  ORT_THROW_IF_ERROR(sparse.UseCooIndices(coo_indices));
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(indices_data);
  ORT_UNUSED_PARAMETER(indices_num);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Hands caller-owned CSR inner/outer indices to the sparse tensor held by `ort_value`.
// For each of the two index arrays an empty span is passed when its count is 0 or its
// pointer is null. Errors from SparseTensor::UseCsrIndices are raised as exceptions and
// translated by API_IMPL_END. With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::UseCsrIndices, _Inout_ OrtValue* ort_value,
                    _Inout_ int64_t* inner_data, size_t inner_num,
                    _Inout_ int64_t* outer_data, size_t outer_num) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  auto& sparse = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);

  // Builds a span over [data, data + count), or an empty span for a null/zero-length input.
  const auto to_span = [](int64_t* data, size_t count) -> gsl::span<int64_t> {
    if (count == 0 || data == nullptr) {
      return gsl::span<int64_t>();
    }
    return gsl::make_span(data, count);
  };

  auto inner_indices = to_span(inner_data, inner_num);
  auto outer_indices = to_span(outer_data, outer_num);
  ORT_THROW_IF_ERROR(sparse.UseCsrIndices(inner_indices, outer_indices));
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(inner_data);
  ORT_UNUSED_PARAMETER(inner_num);
  ORT_UNUSED_PARAMETER(outer_data);
  ORT_UNUSED_PARAMETER(outer_num);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Hands caller-owned block-sparse indices to the sparse tensor held by `ort_value`.
//   indices_shape / indices_shape_len - shape of the indices
//   indices_data                      - the int32 indices themselves
// Errors from SparseTensor::UseBlockSparseIndices are raised as exceptions and translated by
// API_IMPL_END. With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape,
                    size_t indices_shape_len, _Inout_ int32_t* indices_data) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  auto& sparse = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);
  TensorShape indices_dims(indices_shape, indices_shape_len);
  ORT_THROW_IF_ERROR(sparse.UseBlockSparseIndices(indices_dims, indices_data));
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(indices_shape);
  ORT_UNUSED_PARAMETER(indices_shape_len);
  ORT_UNUSED_PARAMETER(indices_data);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Reports the storage format of the sparse tensor held by `ort_value`.
//   out - receives the tensor's Format() value converted to OrtSparseFormat
// Returns ORT_INVALID_ARGUMENT when `ort_value` is not allocated.
// With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorFormat, _In_ const OrtValue* ort_value,
                    _Out_ enum OrtSparseFormat* out) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  if (!ort_value->IsAllocated()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "the ort_value must contain a constructed tensor");
  }
  const auto& sparse = ort_value->Get<SparseTensor>();
  *out = static_cast<OrtSparseFormat>(sparse.Format());
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(out);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Exposes the raw values buffer of the sparse tensor held by `ort_value`.
//   out - receives the DataRaw() pointer of the tensor's Values()
// String tensors are rejected with ORT_INVALID_ARGUMENT (the GetStringTensor*() APIs are to be
// used instead). With DISABLE_SPARSE_TENSORS this always returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValues, _In_ const OrtValue* ort_value,
                    _Outptr_ const void** out) {
  API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
  const auto& sparse = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);
  if (sparse.IsDataTypeString()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Use GetStringTensor*() API to retrieve strings");
  }
  *out = sparse.Values().DataRaw();
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(ort_value);
  ORT_UNUSED_PARAMETER(out);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#endif
  API_IMPL_END
}
// Allocates a new OrtCustomOpDomain whose domain_ is set from `domain`.
//   out - receives the new object; ownership passes to the caller
//         (OrtApis::ReleaseCustomOpDomain deletes it).
ORT_API_STATUS_IMPL(OrtApis::CreateCustomOpDomain, _In_ const char* domain,
                    _Outptr_ OrtCustomOpDomain** out) {
  API_IMPL_BEGIN
  auto new_domain = std::make_unique<OrtCustomOpDomain>();
  new_domain->domain_ = domain;
  *out = new_domain.release();
  return nullptr;
  API_IMPL_END
}
// Destroys a custom op domain with `delete`. Passing nullptr is a no-op.
ORT_API(void, OrtApis::ReleaseCustomOpDomain, _Frees_ptr_opt_ OrtCustomOpDomain* ptr) {
delete ptr;
}
// Appends `op` to the domain's custom_ops_ list. Only the pointer is stored here.
// Always returns nullptr (success) unless an exception is translated by API_IMPL_END.
ORT_API_STATUS_IMPL(OrtApis::CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op) {
API_IMPL_BEGIN
custom_op_domain->custom_ops_.emplace_back(op);
return nullptr;
API_IMPL_END
}
// Appends `custom_op_domain` to the session options' custom_op_domains_ list.
// Only the pointer is stored here.
// Always returns nullptr (success) unless an exception is translated by API_IMPL_END.
ORT_API_STATUS_IMPL(OrtApis::AddCustomOpDomain, _Inout_ OrtSessionOptions* options,
_In_ OrtCustomOpDomain* custom_op_domain) {
API_IMPL_BEGIN
options->custom_op_domains_.emplace_back(custom_op_domain);
return nullptr;
API_IMPL_END
}
// Loads the shared library at `library_path`, looks up its "RegisterCustomOps" entry point and
// calls it with `options` and the API base.
//   library_handle - receives the handle of the loaded library. It is written by the load call
//                    and is not released here, including on the later failure paths.
// Returns the status of the load/lookup on failure, ORT_FAIL when the handle or the entry point
// is null, and otherwise whatever the library's RegisterCustomOps returns.
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options,
                    _In_ const char* library_path, _Outptr_ void** library_handle) {
  API_IMPL_BEGIN
  const Env& platform_env = Env::Default();

  auto native_path = ToPathString(library_path);
  ORT_API_RETURN_IF_STATUS_NOT_OK(platform_env.LoadDynamicLibrary(native_path, false, library_handle));
  if (*library_handle == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "RegisterCustomOpsLibrary: Failed to load library");
  }

  RegisterCustomOpsFn register_fn;
  ORT_API_RETURN_IF_STATUS_NOT_OK(platform_env.GetSymbolFromLibrary(
      *library_handle, "RegisterCustomOps", reinterpret_cast<void**>(&register_fn)));
  if (register_fn == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL,
                                 "RegisterCustomOpsLibrary: Entry point RegisterCustomOps not found in library");
  }

  return register_fn(options, OrtGetApiBase());
  API_IMPL_END
}
// Registers the custom ops library `library_name` by delegating to
// OrtSessionOptions::RegisterCustomOpsLibrary.
// In minimal builds without ORT_MINIMAL_BUILD_CUSTOM_OPS this returns ORT_NOT_IMPLEMENTED.
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options,
                    _In_ const ORTCHAR_T* library_name) {
  API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
  auto registration_status = options->RegisterCustomOpsLibrary(library_name);
  if (!registration_status.IsOK()) {
    return ToOrtStatus(registration_status);
  }
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(options);
  ORT_UNUSED_PARAMETER(library_name);
  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Custom operator libraries are not supported in this build");
#endif
  API_IMPL_END
}
// Looks up a registration function by name (with a null library handle passed to
// GetSymbolFromLibrary) and calls it with `options` and the API base.
// Returns ORT_INVALID_ARGUMENT when the name is null or the symbol resolves to null, the lookup
// status on lookup failure, and otherwise whatever the registration function returns.
// In minimal builds without ORT_MINIMAL_BUILD_CUSTOM_OPS this returns ORT_NOT_IMPLEMENTED.
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options,
                    _In_ const char* registration_func_name) {
  API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
  if (registration_func_name == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "RegisterCustomOpsUsingFunction: Registration function name must be specified.");
  }

  RegisterCustomOpsFn register_fn;
  ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().GetSymbolFromLibrary(
      nullptr, registration_func_name, reinterpret_cast<void**>(&register_fn)));
  if (register_fn == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "RegisterCustomOpsUsingFunction: Registration function was not found");
  }

  return register_fn(options, OrtGetApiBase());
#else
  ORT_UNUSED_PARAMETER(options);
  ORT_UNUSED_PARAMETER(registration_func_name);
  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Custom operator libraries are not supported in this build");
#endif
  API_IMPL_END
}
// Registers the onnxruntime-extensions custom ops on `options`.
//  - null `options`: nothing to do, returns nullptr.
//  - built with ENABLE_EXTENSION_CUSTOM_OPS: returns the result of RegisterCustomOps.
//  - otherwise: returns ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* options) {
  API_IMPL_BEGIN
  if (!options) {
    return nullptr;
  }
#ifdef ENABLE_EXTENSION_CUSTOM_OPS
  return RegisterCustomOps(options, OrtGetApiBase());
#else
  return OrtApis::CreateStatus(ORT_FAIL, "EnableOrtCustomOps: Custom operators in onnxruntime-extensions are not enabled");
#endif
  API_IMPL_END
}
namespace {
// Provide either model_path, or model_data + model_data_length.
//
// Constructs an InferenceSession into `sess`, adds any custom op domains carried by `options`
// (when custom ops are compiled in) and loads the model.
//   options           - may be null, in which case default SessionOptions are used
//   env               - supplies the Environment the session is created in
//   model_path        - path of the model, or null to load from `model_data`
//   model_data        - in-memory model bytes; only used when `model_path` is null
//   model_data_length - size of `model_data` in bytes; must fit in an int
//   sess              - receives the created session (it may have been assigned even when an
//                       error is returned from a later step)
// Returns nullptr on success, otherwise the failing step's status.
static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,
                                                _In_ const OrtEnv* env,
                                                _In_opt_z_ const ORTCHAR_T* model_path,
                                                _In_opt_ const void* model_data,
                                                size_t model_data_length,
                                                std::unique_ptr<onnxruntime::InferenceSession>& sess) {
  // The in-memory load paths below pass the buffer length on as an int. Reject lengths that do
  // not fit instead of letting static_cast silently truncate (or wrap negative) and load only a
  // prefix of the caller's buffer.
  // (max)() is parenthesised so a `max` macro from platform headers cannot interfere.
  if (model_path == nullptr &&
      model_data_length > static_cast<size_t>((std::numeric_limits<int>::max)())) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "model_data_length is too large: in-memory model buffers must not exceed INT_MAX bytes");
  }

  // quick check here to decide load path. InferenceSession will provide error message for invalid values.
  // TODO: Could move to a helper
  const Env& os_env = Env::Default();  // OS environment (!= ORT environment)
  bool load_config_from_model =
      os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1";

  // A null `options` means "use defaults"; resolve that once for all construction paths.
  const onnxruntime::SessionOptions session_options =
      options == nullptr ? onnxruntime::SessionOptions() : options->value;

  if (load_config_from_model) {
#if !defined(ORT_MINIMAL_BUILD)
    // The session needs the model at construction time so it can read the config from it.
    if (model_path != nullptr) {
      sess = std::make_unique<onnxruntime::InferenceSession>(
          session_options,
          env->GetEnvironment(),
          model_path);
    } else {
      sess = std::make_unique<onnxruntime::InferenceSession>(
          session_options,
          env->GetEnvironment(),
          model_data, static_cast<int>(model_data_length));
    }
#else
    return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build.");
#endif
  } else {
    sess = std::make_unique<onnxruntime::InferenceSession>(
        session_options,
        env->GetEnvironment());
  }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
  // Add custom domains
  if (options && !options->custom_op_domains_.empty()) {
    ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_));
  }
#endif

  // Finish load
  if (load_config_from_model) {
#if !defined(ORT_MINIMAL_BUILD)
    // The model source was already given to the constructor above.
    ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load());
#endif
  } else {
    if (model_path != nullptr) {
      ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path));
    } else {
      ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast<int>(model_data_length)));
    }
  }
  return nullptr;
}
// Finishes setting up a loaded session:
//  1. creates one execution provider per factory in `options` (when `options` is non-null),
//  2. registers every non-null provider with the session,
//  3. attaches `prepacked_weights_container` when one is given,
//  4. calls Initialize() on the session.
// The first failing step's status is returned; nullptr means success.
// NOTE(review): the previous comment here talked about disabling memory patterns when DML is a
// provider; no such logic is present in this function.
static ORT_STATUS_PTR InitializeSession(_In_ const OrtSessionOptions* options,
                                        _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess,
                                        _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr) {
  // Create all providers first; registration only starts once every factory has run.
  std::vector<std::unique_ptr<IExecutionProvider>> created_providers;
  if (options) {
    for (auto& factory : options->provider_factories) {
      created_providers.push_back(factory->CreateProvider());
    }
  }

  for (auto& provider : created_providers) {
    if (!provider) {
      continue;  // a factory may yield no provider; skip it
    }
    ORT_API_RETURN_IF_STATUS_NOT_OK(sess->RegisterExecutionProvider(std::move(provider)));
  }

  if (prepacked_weights_container != nullptr) {
    auto* container = reinterpret_cast<PrepackedWeightsContainer*>(prepacked_weights_container);
    ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddPrePackedWeightsContainer(container));
  }

  ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Initialize());
  return nullptr;
}
} // namespace
// Creates an inference session for the model stored at `model_path`.
//   env     - environment the session is created in
//   options - session options forwarded to the load and initialize helpers
//   out     - set to nullptr up front; receives the new session on success (caller owns it)
// Errors from the helpers are returned directly; a std::exception raised inside the ORT_TRY
// block is converted to an ORT_FAIL status carrying e.what().
ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
                    _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
  API_IMPL_BEGIN
  *out = nullptr;
  OrtStatus* result = nullptr;
  std::unique_ptr<onnxruntime::InferenceSession> session;

  ORT_TRY {
    ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, session));
    ORT_API_RETURN_IF_ERROR(InitializeSession(options, session));
    // Hand ownership of the fully initialized session to the caller.
    *out = reinterpret_cast<OrtSession*>(session.release());
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      result = OrtApis::CreateStatus(ORT_FAIL, e.what());
    });
  }

  return result;
  API_IMPL_END
}
// Creates an inference session for a model held in memory.
//   env                            - environment the session is created in
//   model_data / model_data_length - the model bytes and their size
//   options                        - session options forwarded to the load and initialize helpers
//   out                            - set to nullptr up front; receives the new session on
//                                    success (caller owns it)
// Errors from the helpers are returned directly; a std::exception raised inside the ORT_TRY
// block is converted to an ORT_FAIL status carrying e.what().
ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data,
                    size_t model_data_length, _In_ const OrtSessionOptions* options,
                    _Outptr_ OrtSession** out) {
  API_IMPL_BEGIN
  *out = nullptr;
  OrtStatus* result = nullptr;
  std::unique_ptr<onnxruntime::InferenceSession> session;

  ORT_TRY {
    ORT_API_RETURN_IF_ERROR(
        CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, session));
    ORT_API_RETURN_IF_ERROR(InitializeSession(options, session));
    // Hand ownership of the fully initialized session to the caller.
    *out = reinterpret_cast<OrtSession*>(session.release());
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      result = OrtApis::CreateStatus(ORT_FAIL, e.what());
    });
  }

  return result;
  API_IMPL_END
}
// Runs the session on the given inputs and produces the requested outputs.
//   sess                             - the session to run
//   run_options                      - may be null, in which case default OrtRunOptions are used
//   input_names / input / input_len  - parallel arrays of feed names and values
//   output_names1 / output_names_len - names of the outputs to fetch
//   output                           - one slot per output name. A non-null slot is passed to the
//                                      run as the initial fetch value and is left untouched
//                                      afterwards; a null slot receives a newly allocated
//                                      OrtValue that the caller owns.
// Returns ORT_INVALID_ARGUMENT for an empty/null input or output name or a null input value,
// the run's status if it fails, and nullptr on success.
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
                    _Inout_updates_all_(output_names_len) OrtValue** output) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

  // Validate and copy the feeds.
  std::vector<std::string> feed_names(input_len);
  std::vector<OrtValue> feeds(input_len);
  for (size_t i = 0; i != input_len; ++i) {
    if (input_names[i] == nullptr || input_names[i][0] == '\0') {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
    }
    if (!input[i]) {
      std::ostringstream ostr;
      ostr << "NULL input supplied for input " << input_names[i];
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str());
    }
    feed_names[i] = input_names[i];
    feeds[i] = *input[i];
  }

  // Validate and copy the output names.
  std::vector<std::string> output_names(output_names_len);
  for (size_t i = 0; i != output_names_len; ++i) {
    if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
    }
    output_names[i] = output_names1[i];
  }

  // Seed the fetches with any values the caller supplied.
  std::vector<OrtValue> fetches(output_names_len);
  for (size_t i = 0; i != output_names_len; ++i) {
    if (output[i] != nullptr) {
      fetches[i] = *output[i];
    }
  }

  Status status;
  if (run_options == nullptr) {
    OrtRunOptions default_run_options;
    status = session->Run(default_run_options, feed_names, feeds, output_names, &fetches, nullptr);
  } else {
    status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
  }
  if (!status.IsOK()) {
    return ToOrtStatus(status);
  }

  // Allocate every missing output into owning pointers first. If an allocation throws part-way
  // through, the ones already made are freed and the caller's array has not been modified, so
  // an error status never comes with half-populated (and easily leaked) outputs.
  std::vector<std::unique_ptr<OrtValue>> new_outputs(output_names_len);
  for (size_t i = 0; i != output_names_len; ++i) {
    if (output[i] == nullptr) {
      new_outputs[i] = std::make_unique<OrtValue>(fetches[i]);
    }
  }
  // Nothing below can throw: publish the new values to the caller.
  for (size_t i = 0; i != output_names_len; ++i) {
    if (new_outputs[i] != nullptr) {
      output[i] = new_outputs[i].release();
    }
  }
  return nullptr;
  API_IMPL_END
}
// Concrete type behind the C API's OrtIoBinding handle. It is a non-copyable owner of the
// ::onnxruntime::IOBinding that the binding-related API functions operate on.
struct OrtIoBinding {
  OrtIoBinding(const OrtIoBinding&) = delete;
  OrtIoBinding& operator=(const OrtIoBinding&) = delete;

  // Takes ownership of `binding`.
  explicit OrtIoBinding(std::unique_ptr<::onnxruntime::IOBinding>&& binding)
      : binding_(std::move(binding)) {}

  std::unique_ptr<::onnxruntime::IOBinding> binding_;
};
// Runs the session using the inputs/outputs recorded in `binding_ptr`.
// A null `run_options` selects a default-constructed OrtRunOptions.
ORT_API_STATUS_IMPL(OrtApis::RunWithBinding, _Inout_ OrtSession* sess, _In_ const OrtRunOptions* run_options,
                    _In_ const OrtIoBinding* binding_ptr) {
  API_IMPL_BEGIN
  auto* inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

  Status run_status;
  if (run_options != nullptr) {
    run_status = inference_session->Run(*run_options, *binding_ptr->binding_);
  } else {
    OrtRunOptions fallback_options;
    run_status = inference_session->Run(fallback_options, *binding_ptr->binding_);
  }

  if (run_status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(run_status);
  API_IMPL_END
}
// Asks the session for a new IOBinding and wraps it in a heap-allocated OrtIoBinding returned via `out`.
// `*out` is only written when the session reports success.
ORT_API_STATUS_IMPL(OrtApis::CreateIoBinding, _Inout_ OrtSession* sess, _Outptr_ OrtIoBinding** out) {
  API_IMPL_BEGIN
  std::unique_ptr<::onnxruntime::IOBinding> io_binding;
  auto create_status = reinterpret_cast<::onnxruntime::InferenceSession*>(sess)->NewIOBinding(&io_binding);
  if (create_status.IsOK()) {
    GSL_SUPPRESS(r .11)
    *out = new OrtIoBinding(std::move(io_binding));
    return nullptr;
  }
  return ToOrtStatus(create_status);
  API_IMPL_END
}
// Destroys an OrtIoBinding (and, through its unique_ptr member, the IOBinding it owns).
// Passing nullptr is harmless because deleting a null pointer is a no-op.
ORT_API(void, OrtApis::ReleaseIoBinding, _Frees_ptr_opt_ OrtIoBinding* binding_ptr) {
delete binding_ptr;
}
// Binds `*val_ptr` as the input called `name` on the wrapped IOBinding.
// Any failure reported by IOBinding::BindInput is converted to an OrtStatus.
ORT_API_STATUS_IMPL(OrtApis::BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr) {
  API_IMPL_BEGIN
  auto bind_status = binding_ptr->binding_->BindInput(name, *val_ptr);
  if (bind_status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(bind_status);
  API_IMPL_END
}
// Binds `*val_ptr` as the output called `name` on the wrapped IOBinding.
// Any failure reported by IOBinding::BindOutput is converted to an OrtStatus.
ORT_API_STATUS_IMPL(OrtApis::BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr) {
  API_IMPL_BEGIN
  auto bind_status = binding_ptr->binding_->BindOutput(name, *val_ptr);
  if (bind_status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(bind_status);
  API_IMPL_END
}
// Binds the output called `name` to the device described by `mem_info_ptr->device`
// instead of to a concrete OrtValue.
ORT_API_STATUS_IMPL(OrtApis::BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* mem_info_ptr) {
  API_IMPL_BEGIN
  auto bind_status = binding_ptr->binding_->BindOutput(name, mem_info_ptr->device);
  if (bind_status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(bind_status);
  API_IMPL_END
}
// Returns the names of all bound outputs packed back-to-back (no separators, no terminators) in one
// character buffer, together with an array holding the length of each name.
//
// Both arrays are allocated with `allocator` and ownership passes to the caller on success.
// With no bound outputs, `*buffer` and `*lengths` are set to nullptr and `*count` to 0.
// If either allocation yields null, ORT_FAIL is returned and anything already allocated is freed.
ORT_API_STATUS_IMPL(OrtApis::GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator,
                    _Out_ char** buffer, _Outptr_result_maybenull_ size_t** lengths, _Out_ size_t* count) {
  API_IMPL_BEGIN
  const auto& bound_names = binding_ptr->binding_->GetOutputNames();
  if (bound_names.empty()) {
    *buffer = nullptr;
    *lengths = nullptr;
    *count = 0U;
    return nullptr;
  }

  // Per-name length table; the guard frees it on any early exit.
  IAllocatorUniquePtr<size_t> lengths_guard(
      reinterpret_cast<size_t*>(allocator->Alloc(allocator, bound_names.size() * sizeof(size_t))),
      [allocator](size_t* p) { if (p) allocator->Free(allocator, p); });
  if (!lengths_guard) {
    return OrtApis::CreateStatus(ORT_FAIL, "lengths allocation failed");
  }

  size_t total_chars = 0;
  size_t slot = 0;
  size_t* const length_table = lengths_guard.get();
  for (const auto& name : bound_names) {
    const auto name_len = name.size();
    length_table[slot++] = name_len;
    total_chars += name_len;
  }

  // Character storage sized for the concatenation of all names.
  IAllocatorUniquePtr<char> chars_guard(
      reinterpret_cast<char*>(allocator->Alloc(allocator, total_chars * sizeof(char))),
      [allocator](char* p) { if (p) allocator->Free(allocator, p); });
  if (!chars_guard) {
    return OrtApis::CreateStatus(ORT_FAIL, "string buffer allocation failed");
  }

  char* write_pos = chars_guard.get();
  for (const auto& name : bound_names) {
    const auto name_len = name.size();
    memcpy(write_pos, name.data(), name_len);
    write_pos += name_len;
  }

  // Success: transfer ownership of both buffers to the caller.
  *buffer = chars_guard.release();
  *lengths = lengths_guard.release();
  *count = bound_names.size();
  return nullptr;
  API_IMPL_END
}
// Returns copies of all bound output values as an array of newly allocated OrtValue pointers.
//
// The pointer array itself is obtained from `allocator`; each element is created with `new OrtValue(...)`.
// On success ownership of the array and its elements passes to the caller via `*output`, and
// `*output_count` receives the element count. With no bound outputs, `*output` is nullptr and
// `*output_count` is 0. A null result from the array allocation yields ORT_FAIL.
ORT_API_STATUS_IMPL(OrtApis::GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator,
_Outptr_result_maybenull_ OrtValue*** output, _Out_ size_t* output_count) {
API_IMPL_BEGIN
const auto& outputs = binding_ptr->binding_->GetOutputs();
if (outputs.empty()) {
*output = nullptr;
*output_count = 0U;
return nullptr;
}
// Used to destroy and de-allocate on exception
// `created` must be declared before the guard below: the guard's deleter captures it by reference
// and uses it to know how many array slots hold live OrtValue objects.
size_t created = 0;
IAllocatorUniquePtr<OrtValue*> ortvalues_alloc(reinterpret_cast<OrtValue**>(allocator->Alloc(allocator, outputs.size() * sizeof(OrtValue*))),
[&created, allocator](OrtValue** buffer) {
if (buffer) {
// Delete the already-constructed values in reverse order, then free the array.
while (created > 0) {
auto p = buffer + --created;
delete (*p);
}
allocator->Free(allocator, buffer);
}
});
if (!ortvalues_alloc) {
return OrtApis::CreateStatus(ORT_FAIL, "Output buffer allocation failed");
}
OrtValue** out_ptr = ortvalues_alloc.get();
for (const auto& out_value : outputs) {
GSL_SUPPRESS(r .11)
*out_ptr = new OrtValue(out_value);
++out_ptr;
// Count only after the slot holds a valid pointer so the deleter never touches an unset slot.
++created;
}
assert(created == outputs.size());
// release() disarms the guard, so the cleanup deleter does not run on the success path.
*output = ortvalues_alloc.release();
*output_count = created;
return nullptr;
API_IMPL_END
}
// Forwards to IOBinding::ClearInputs() on the wrapped binding.
// `binding_ptr` is dereferenced without a null check.
ORT_API(void, OrtApis::ClearBoundInputs, _Inout_ OrtIoBinding* binding_ptr) {
binding_ptr->binding_->ClearInputs();
}
// Forwards to IOBinding::ClearOutputs() on the wrapped binding.
// `binding_ptr` is dereferenced without a null check.
ORT_API(void, OrtApis::ClearBoundOutputs, _Inout_ OrtIoBinding* binding_ptr) {
binding_ptr->binding_->ClearOutputs();
}
// Calls IOBinding::SynchronizeInputs() on the wrapped binding and converts a failure to an OrtStatus.
ORT_API_STATUS_IMPL(OrtApis::SynchronizeBoundInputs, _Inout_ OrtIoBinding* binding_ptr) {
  API_IMPL_BEGIN
  auto sync_status = binding_ptr->binding_->SynchronizeInputs();
  if (sync_status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(sync_status);
  API_IMPL_END
}
// Calls IOBinding::SynchronizeOutputs() on the wrapped binding and converts a failure to an OrtStatus.
ORT_API_STATUS_IMPL(OrtApis::SynchronizeBoundOutputs, _Inout_ OrtIoBinding* binding_ptr) {
  API_IMPL_BEGIN
  auto sync_status = binding_ptr->binding_->SynchronizeOutputs();
  if (sync_status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(sync_status);
  API_IMPL_END
}
// Writes 1 to `*out` when `value` holds a tensor and 0 otherwise. Always returns success.
ORT_API_STATUS_IMPL(OrtApis::IsTensor, _In_ const OrtValue* value, _Out_ int* out) {
  const ::OrtValue* ort_value = reinterpret_cast<const ::OrtValue*>(value);
  if (ort_value->IsTensor()) {
    *out = 1;
  } else {
    *out = 0;
  }
  return nullptr;
}
// Writes 1 to `*out` when `value` reports IsAllocated() and 0 otherwise. Always returns success.
ORT_API_STATUS_IMPL(OrtApis::HasValue, _In_ const OrtValue* value, _Out_ int* out) {
  const ::OrtValue* ort_value = reinterpret_cast<const ::OrtValue*>(value);
  if (ort_value->IsAllocated()) {
    *out = 1;
  } else {
    *out = 0;
  }
  return nullptr;
}
// Writes 1 to `*out` when `value` holds a sparse tensor and 0 otherwise.
// In builds with DISABLE_SPARSE_TENSORS defined the query is rejected with ORT_FAIL and `*out` is untouched.
ORT_API_STATUS_IMPL(OrtApis::IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out) {
#if defined(DISABLE_SPARSE_TENSORS)
  ORT_UNUSED_PARAMETER(value);
  ORT_UNUSED_PARAMETER(out);
  return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build.");
#else
  const ::OrtValue* ort_value = reinterpret_cast<const ::OrtValue*>(value);
  if (ort_value->IsSparseTensor()) {
    *out = 1;
  } else {
    *out = 0;
  }
  return nullptr;
#endif
}
// Returns, via `*output`, the raw mutable data pointer of the tensor held by `value`.
// `tensor` is introduced by the TENSOR_READWRITE_API_BEGIN macro, which is defined outside this
// section (presumably it extracts the mutable Tensor from `value` — confirm against its definition).
ORT_API_STATUS_IMPL(OrtApis::GetTensorMutableData, _Inout_ OrtValue* value, _Outptr_ void** output) {
TENSOR_READWRITE_API_BEGIN
// The string-tensor rejection below is intentionally disabled for now:
// Uncomment when WinML fixed their code
// if (tensor->IsDataTypeString()) {
// return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "this API does not support strings");
//}
*output = tensor->MutableDataRaw();
return nullptr;
API_IMPL_END
}
// Overwrites every element of a string tensor with a copy of the corresponding C string in `s`.
// `s_len` must equal the tensor's element count, otherwise ORT_INVALID_ARGUMENT is returned
// and the tensor is left unchanged.
ORT_API_STATUS_IMPL(OrtApis::FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len) {
  TENSOR_READWRITE_API_BEGIN
  // Fetch the typed buffer first so the element type is checked before the size comparison.
  std::string* elements = tensor->MutableData<std::string>();
  const size_t element_count = static_cast<size_t>(tensor->Shape().Size());
  if (element_count != s_len) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input array doesn't equal tensor size");
  }

  // Each assignment allocates and copies the characters into the tensor-owned std::string.
  const char* const* source = s;
  for (std::string* target = elements; target != elements + element_count; ++target, ++source) {
    *target = *source;
  }
  return nullptr;
  API_IMPL_END
}
// Replaces the string tensor element at `index` with a copy of the C string `s`.
// Returns ORT_INVALID_ARGUMENT when `index` is not smaller than the tensor's element count.
ORT_API_STATUS_IMPL(OrtApis::FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index) {
  TENSOR_READWRITE_API_BEGIN
  // Fetch the typed buffer first so the element type is checked before the bounds comparison.
  std::string* elements = tensor->MutableData<std::string>();
  const size_t element_count = static_cast<size_t>(tensor->Shape().Size());
  if (index < element_count) {
    elements[index] = s;
    return nullptr;
  }
  return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "element index is out of bounds");
  API_IMPL_END
}
namespace {
// Produces a read-only span over the std::string elements stored in `v`.
//
// `v` may hold a dense Tensor or (unless DISABLE_SPARSE_TENSORS is defined) a SparseTensor, in which
// case the span covers the sparse tensor's Values(). `span` is assigned only on success.
// Returns:
//   - ORT_INVALID_ARGUMENT if `v` is not allocated, a sparse tensor has no sparse data,
//     or the shape reports a negative size;
//   - ORT_NOT_IMPLEMENTED if `v` holds anything other than the supported tensor kinds.
// The element data type is checked by DataAsSpan<std::string>().
OrtStatusPtr GetTensorStringSpan(const ::OrtValue& v, gsl::span<const std::string>& span) {
  if (!v.IsAllocated()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue should contain a Tensor or a Sparse Tensor");
  }

  if (v.IsTensor()) {
    const auto& dense = v.Get<onnxruntime::Tensor>();
    const int64_t element_count = dense.Shape().Size();
    if (element_count < 0) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "shape is invalid");
    }
    span = dense.DataAsSpan<std::string>();
    return nullptr;
  }

#if !defined(DISABLE_SPARSE_TENSORS)
  if (v.IsSparseTensor()) {
    const auto& sparse = v.Get<SparseTensor>();
    if (sparse.Format() == onnxruntime::SparseFormat::kUndefined) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sparse Tensor does not contain sparse data");
    }
    const int64_t element_count = sparse.Values().Shape().Size();
    if (element_count < 0) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "shape is invalid");
    }
    span = sparse.Values().DataAsSpan<std::string>();
    return nullptr;
  }
#endif

  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API supports Tensors or SparseTensors");
}
}  // namespace
// Computes the combined length (sum of std::string::size()) of all strings in the tensor and
// stores it in `*out`. No terminators are counted.
ORT_API_STATUS_IMPL(OrtApis::GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* out) {
  API_IMPL_BEGIN
  gsl::span<const std::string> strings;
  OrtStatusPtr span_status = GetTensorStringSpan(*value, strings);
  if (span_status != nullptr) {
    return span_status;
  }

  size_t total_bytes = 0;
  for (const std::string& element : strings) {
    total_bytes += element.size();
  }
  *out = total_bytes;
  return nullptr;
  API_IMPL_END
}
// Stores in `*out` the length of the string element at `index`.
// Returns ORT_INVALID_ARGUMENT when `index` is past the last element; `*out` is then left untouched.
ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out) {
  API_IMPL_BEGIN
  gsl::span<const std::string> strings;
  OrtStatusPtr span_status = GetTensorStringSpan(*value, strings);
  if (span_status != nullptr) {
    return span_status;
  }

  if (index >= strings.size()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "index is out of bounds");
  }
  *out = strings[index].size();
  return nullptr;
  API_IMPL_END
}
// Copies all strings of the tensor back-to-back (no terminators) into the buffer `s` and records in
// `offsets[i]` the byte position at which element i starts.
//
// `offsets_len` must equal the number of elements, and `s_len` must be at least the combined length
// of all strings (see GetStringTensorDataLength); otherwise ORT_FAIL is returned and nothing is written.
ORT_API_STATUS_IMPL(OrtApis::GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s,
                    size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len) {
  API_IMPL_BEGIN
  gsl::span<const std::string> strings;
  OrtStatusPtr span_status = GetTensorStringSpan(*value, strings);
  if (span_status != nullptr) {
    return span_status;
  }

  if (strings.size() != offsets_len) {
    return OrtApis::CreateStatus(ORT_FAIL, "offsets buffer is not equal to tensor size");
  }

  size_t required_bytes = 0;
  for (const std::string& element : strings) {
    required_bytes += element.size();
  }
  if (required_bytes > s_len) {
    return OrtApis::CreateStatus(ORT_FAIL, "output buffer is too small. Use GetStringTensorDataLength.");
  }

  char* const destination = static_cast<char*>(s);
  size_t written = 0;
  size_t slot = 0;
  for (const std::string& element : strings) {
    const size_t element_len = element.size();
    memcpy(destination + written, element.data(), element_len);
    offsets[slot++] = written;
    written += element_len;
  }
  return nullptr;
  API_IMPL_END
}
// Copies the characters of the string element at `index` into the buffer `s` (no terminator is added).
// Returns ORT_INVALID_ARGUMENT for an out-of-range `index` and ORT_FAIL when `s_len` is smaller than
// the element's length.
ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElement, _In_ const OrtValue* value,
                    size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s) {
  API_IMPL_BEGIN
  gsl::span<const std::string> strings;
  OrtStatusPtr span_status = GetTensorStringSpan(*value, strings);
  if (span_status != nullptr) {
    return span_status;
  }

  if (index >= strings.size()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "element index is out of bounds");
  }

  const std::string& element = strings[index];
  if (element.size() > s_len) {
    return OrtApis::CreateStatus(ORT_FAIL, "buffer size is too small for string element");
  }
  memcpy(s, element.data(), element.size());
  return nullptr;
  API_IMPL_END
}
// Evaluates `expr` exactly once; if the resulting status object is not OK, returns from the enclosing
// function with that status converted by ToOrtStatus. The do { ... } while (0) wrapper lets the macro
// be used as a single statement followed by a semicolon.
#define ORT_C_API_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if ((!_status.IsOK())) return ToOrtStatus(_status); \
} while (0)
// Generates the definition of OrtApis::Release<INPUT_TYPE>: the opaque Ort<INPUT_TYPE>* handle is
// cast to the implementation type REAL_TYPE* and deleted.
#define DEFINE_RELEASE_ORT_OBJECT_FUNCTION(INPUT_TYPE, REAL_TYPE) \
ORT_API(void, OrtApis::Release##INPUT_TYPE, _Frees_ptr_opt_ Ort##INPUT_TYPE* value) { \
delete reinterpret_cast<REAL_TYPE*>(value); \
}
// Result of querying a session for one of its definition lists: a status plus a pointer to the list.
using DefListResult = std::pair<Status, const InputDefList*>;
// Plain function-pointer type used to select which definition list a helper should operate on.
using GetDefListFn = DefListResult (*)(const ::onnxruntime::InferenceSession*);
// Capture-less lambdas (convertible to GetDefListFn) selecting the model inputs, the model outputs
// and the overridable initializers respectively.
const auto get_inputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetModelInputs(); };
const auto get_outputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetModelOutputs(); };
const auto get_overridable_initializers_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetOverridableInitializers(); };
// Shared implementation of the SessionGet*Count functions: fetches the definition list selected by
// `get_fn` and stores its size in `*out`. A failing lookup status is returned as an OrtStatus.
static ORT_STATUS_PTR GetNodeDefListCountHelper(const OrtSession* sess, GetDefListFn get_fn, size_t* out) {
  API_IMPL_BEGIN
  const DefListResult lookup = get_fn(reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess));
  if (lookup.first.IsOK()) {
    *out = lookup.second->size();
    return nullptr;
  }
  return ToOrtStatus(lookup.first);
  API_IMPL_END
}
// Stores the number of model inputs in `*out` (delegates to GetNodeDefListCountHelper with get_inputs_fn).
ORT_API_STATUS_IMPL(OrtApis::SessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_t* out) {
return GetNodeDefListCountHelper(sess, get_inputs_fn, out);
}
// Stores the number of model outputs in `*out` (delegates to GetNodeDefListCountHelper with get_outputs_fn).
ORT_API_STATUS_IMPL(OrtApis::SessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out) {
return GetNodeDefListCountHelper(sess, get_outputs_fn, out);
}
// Stores the number of overridable initializers in `*out`
// (delegates to GetNodeDefListCountHelper with get_overridable_initializers_fn).
ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerCount, _In_ const OrtSession* sess, _Out_ size_t* out) {
return GetNodeDefListCountHelper(sess, get_overridable_initializers_fn, out);
}
// Shared implementation of the SessionGet*TypeInfo functions: looks up entry `index` in the definition
// list selected by `get_fn` and builds an OrtTypeInfo from its TypeProto.
// Returns ORT_FAIL ("out of index") when `index` is not a valid position in the list.
static ORT_STATUS_PTR GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn get_fn, size_t index,
                                               _Outptr_ struct OrtTypeInfo** out) {
  API_IMPL_BEGIN
  const DefListResult lookup = get_fn(reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess));
  if (!lookup.first.IsOK()) {
    return ToOrtStatus(lookup.first);
  }

  const InputDefList& defs = *lookup.second;
  if (index >= defs.size()) {
    return OrtApis::CreateStatus(ORT_FAIL, "out of index");
  }

  const ONNX_NAMESPACE::TypeProto* proto = defs[index]->TypeAsProto();
  return OrtTypeInfo::FromTypeProto(proto, out);
  API_IMPL_END
}
// Returns type information for the model input at `index` (delegates to GetNodeDefTypeInfoHelper).
ORT_API_STATUS_IMPL(OrtApis::SessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ struct OrtTypeInfo** out) {
return GetNodeDefTypeInfoHelper(sess, get_inputs_fn, index, out);
}
// Returns type information for the model output at `index` (delegates to GetNodeDefTypeInfoHelper).
ORT_API_STATUS_IMPL(OrtApis::SessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ struct OrtTypeInfo** out) {
return GetNodeDefTypeInfoHelper(sess, get_outputs_fn, index, out);
}
// Returns type information for the overridable initializer at `index`
// (delegates to GetNodeDefTypeInfoHelper).
ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ struct OrtTypeInfo** out) {
return GetNodeDefTypeInfoHelper(sess, get_overridable_initializers_fn, index, out);
}
// Duplicates `str` into a NUL-terminated buffer of str.size() + 1 bytes obtained from `allocator`.
// The caller owns the returned buffer and must release it with the same allocator.
//
// Returns nullptr if the allocator reports failure by returning a null pointer; previously the
// null result was passed straight to memcpy, which is undefined behavior.
char* onnxruntime::StrDup(const std::string& str, OrtAllocator* allocator) {
  const size_t length = str.size();
  char* output_string = reinterpret_cast<char*>(allocator->Alloc(allocator, length + 1));
  if (output_string == nullptr) {
    // Nothing was allocated, so there is nothing to copy into or to free.
    return nullptr;
  }
  memcpy(output_string, str.data(), length);
  output_string[length] = '\0';
  return output_string;
}
// Shared implementation of the SessionGet*Name functions: duplicates (with `allocator`) the name of
// entry `index` of the definition list selected by `get_fn` and returns it via `*output`.
// Returns ORT_FAIL when the list pointer is null ("internal error") or `index` is out of range.
// This function has no exception guard of its own.
static ORT_STATUS_PTR GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator,
                                         GetDefListFn get_fn, _Outptr_ char** output) {
  const DefListResult lookup = get_fn(reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess));
  if (!lookup.first.IsOK()) {
    return ToOrtStatus(lookup.first);
  }

  const InputDefList* defs = lookup.second;
  if (defs == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "internal error");
  }
  if (defs->size() <= index) {
    return OrtApis::CreateStatus(ORT_FAIL, "index out of range");
  }

  *output = StrDup((*defs)[index]->Name(), allocator);
  return nullptr;
}
// Ends profiling on the session and returns, via `*out`, a copy of the string produced by
// InferenceSession::EndProfiling() (the profile file name), allocated with `allocator`.
ORT_API_STATUS_IMPL(OrtApis::SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator,
                    _Outptr_ char** out) {
  API_IMPL_BEGIN
  auto* inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
  *out = StrDup(inference_session->EndProfiling(), allocator);
  return nullptr;
  API_IMPL_END
}
// Returns, via `*out`, a heap-allocated copy of the session's ModelMetadata exposed as an opaque
// OrtModelMetadata handle. `*out` is only written when the metadata lookup succeeds.
ORT_API_STATUS_IMPL(OrtApis::SessionGetModelMetadata, _In_ const OrtSession* sess,
                    _Outptr_ OrtModelMetadata** out) {
  API_IMPL_BEGIN
  auto lookup = reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess)->GetModelMetadata();
  if (lookup.first.IsOK()) {
    GSL_SUPPRESS(r .11)
    *out = reinterpret_cast<OrtModelMetadata*>(new ModelMetadata(*lookup.second));
    return nullptr;
  }
  return ToOrtStatus(lookup.first);
  API_IMPL_END
}
// Returns, via `*value`, a copy of the metadata's producer_name allocated with `allocator`.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetProducerName,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
  API_IMPL_BEGIN
  const auto* metadata = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata);
  *value = StrDup(metadata->producer_name, allocator);
  return nullptr;
  API_IMPL_END
}
// Returns, via `*value`, a copy of the metadata's graph_name allocated with `allocator`.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphName,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
  API_IMPL_BEGIN
  const auto* metadata = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata);
  *value = StrDup(metadata->graph_name, allocator);
  return nullptr;
  API_IMPL_END
}
// Returns, via `*value`, a copy of the metadata's domain allocated with `allocator`.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDomain,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
  API_IMPL_BEGIN
  const auto* metadata = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata);
  *value = StrDup(metadata->domain, allocator);
  return nullptr;
  API_IMPL_END
}
// Returns, via `*value`, a copy of the metadata's description allocated with `allocator`.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
  API_IMPL_BEGIN
  const auto* metadata = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata);
  *value = StrDup(metadata->description, allocator);
  return nullptr;
  API_IMPL_END
}
// Returns, via `*value`, a copy of the metadata's graph_description allocated with `allocator`.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetGraphDescription,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _Outptr_ char** value) {
  API_IMPL_BEGIN
  const auto* metadata = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata);
  *value = StrDup(metadata->graph_description, allocator);
  return nullptr;
  API_IMPL_END
}
// Looks up `key` in the model's custom metadata map.
// On a hit, `*value` receives a copy of the mapped string allocated with `allocator`;
// on a miss, `*value` is set to nullptr. Both cases return success.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value) {
  API_IMPL_BEGIN
  // Bind by reference: the map is only read, so copying it for a single lookup is pure overhead.
  const auto& custom_metadata_map =
      reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->custom_metadata_map;

  const std::string lookup_key(key);
  const auto iter = custom_metadata_map.find(lookup_key);

  if (iter == custom_metadata_map.end()) {
    *value = nullptr;
  } else {
    *value = StrDup(iter->second, allocator);
  }

  return nullptr;
  API_IMPL_END
}
// Returns all keys of the model's custom metadata map.
//
// On success `*keys` points to an array of `*num_keys` NUL-terminated strings; the array and every
// string are allocated with `allocator` and owned by the caller. For an empty map `*keys` is nullptr
// and `*num_keys` is 0. If an allocation yields null, ORT_FAIL is returned and everything allocated
// so far is released.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetCustomMetadataMapKeys,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys) {
  API_IMPL_BEGIN
  const auto& custom_metadata_map =
      reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata)->custom_metadata_map;

  auto count = custom_metadata_map.size();
  if (count == 0) {
    *keys = nullptr;
  } else {
    // To guard against overflow in the next step where we compute bytes to allocate
    SafeInt<size_t> alloc_count(count);

    // Holders free the individual strings if we bail out before handing them to the caller.
    // Capacity is reserved up front so push_back below cannot reallocate.
    InlinedVector<Ort::AllocatedStringPtr> string_holders;
    string_holders.reserve(count);

    auto deletor = Ort::detail::AllocatedFree(allocator);
    // alloc_count * sizeof(...) will throw if there was an overflow which will be caught in API_IMPL_END
    // and be returned to the user as a status
    char** p = reinterpret_cast<char**>(allocator->Alloc(allocator, alloc_count * sizeof(char*)));
    // An assert alone vanishes in release builds and would let the loop below write through null.
    if (p == nullptr) {
      return OrtApis::CreateStatus(ORT_FAIL, "keys array allocation failed");
    }

    // StrDup may throw
    std::unique_ptr<void, decltype(deletor)> array_guard(p, deletor);

    int64_t i = 0;
    for (const auto& e : custom_metadata_map) {
      auto* s = StrDup(e.first, allocator);
      if (s == nullptr) {
        // The guards release the array and the strings duplicated so far.
        return OrtApis::CreateStatus(ORT_FAIL, "key string allocation failed");
      }
      string_holders.push_back(Ort::AllocatedStringPtr(s, deletor));
      p[i++] = s;
    }

    // Everything succeeded: disarm the guards and transfer ownership to the caller.
    for (auto& s : string_holders) {
      s.release();
    }

    *keys = p;
    array_guard.release();
  }

  *num_keys = static_cast<int64_t>(count);
  return nullptr;
  API_IMPL_END
}
// Stores the metadata's `version` field in `*value`.
ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetVersion,
                    _In_ const OrtModelMetadata* model_metadata,
                    _Out_ int64_t* value) {
  API_IMPL_BEGIN
  const auto* metadata = reinterpret_cast<const ::onnxruntime::ModelMetadata*>(model_metadata);
  *value = metadata->version;
  return nullptr;
  API_IMPL_END
}
// Returns, via `*output`, a copy of the name of the model input at `index`, allocated with `allocator`.
// Delegates to GetNodeDefNameImpl inside the API_IMPL_BEGIN/API_IMPL_END wrapper.
ORT_API_STATUS_IMPL(OrtApis::SessionGetInputName, _In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Outptr_ char** output) {
API_IMPL_BEGIN
return GetNodeDefNameImpl(sess, index, allocator, get_inputs_fn, output);
API_IMPL_END
}
// Returns, via `*output`, a copy of the name of the model output at `index`, allocated with `allocator`.
// Delegates to GetNodeDefNameImpl inside the API_IMPL_BEGIN/API_IMPL_END wrapper.
ORT_API_STATUS_IMPL(OrtApis::SessionGetOutputName, _In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Outptr_ char** output) {
API_IMPL_BEGIN
return GetNodeDefNameImpl(sess, index, allocator, get_outputs_fn, output);
API_IMPL_END
}
// Returns, via `*output`, a copy of the name of the overridable initializer at `index`, allocated
// with `allocator`. Delegates to GetNodeDefNameImpl inside the API_IMPL_BEGIN/API_IMPL_END wrapper.
ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerName, _In_ const OrtSession* sess, size_t index,
_Inout_ OrtAllocator* allocator, _Outptr_ char** output) {
API_IMPL_BEGIN
return GetNodeDefNameImpl(sess, index, allocator, get_overridable_initializers_fn, output);
API_IMPL_END
}
// Calls the allocator's Alloc callback for `size` bytes and stores the result in `*out`.
// The result is not inspected here, so `*out` may be null if the allocator returns null.
ORT_API_STATUS_IMPL(OrtApis::AllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out) {
API_IMPL_BEGIN
*out = ptr->Alloc(ptr, size);
return nullptr;
API_IMPL_END
}
// Calls the allocator's Free callback on `p`.
ORT_API_STATUS_IMPL(OrtApis::AllocatorFree, _Inout_ OrtAllocator* ptr, void* p) {
API_IMPL_BEGIN
ptr->Free(ptr, p);
return nullptr;
API_IMPL_END
}
// Stores in `*out` the OrtMemoryInfo pointer returned by the allocator's Info callback.
ORT_API_STATUS_IMPL(OrtApis::AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Outptr_ const struct OrtMemoryInfo** out) {
API_IMPL_BEGIN
*out = ptr->Info(ptr);
return nullptr;
API_IMPL_END
}
// Stores in `*out` the size() of the container of type T held by `p_ml_value`.
template <typename T>
ORT_STATUS_PTR OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) {
  const auto& container = p_ml_value->Get<T>();
  *out = container.size();
  return nullptr;
}
#if !defined(DISABLE_ML_OPS)
// Value count reported for a map OrtValue by OrtGetValueCountImpl (only available with ML ops enabled).
static constexpr int NUM_MAP_INDICES = 2;
#endif
// Reports, via `*out`, how many sub-values a map or sequence OrtValue exposes:
//   - map: NUM_MAP_INDICES (ORT_FAIL when ML ops are disabled in this build);
//   - tensor sequence: the number of tensors;
//   - sequence of map<string,float> / map<int64,float>: the number of maps (ML ops builds only).
// Any other value type yields ORT_FAIL.
static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) {
  ONNXType value_type;
  if (auto type_status = OrtApis::GetValueType(value, &value_type)) {
    return type_status;
  }

  if (value_type == ONNX_TYPE_MAP) {
#if !defined(DISABLE_ML_OPS)
    *out = NUM_MAP_INDICES;
    return nullptr;
#else
    return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
  }

  if (value_type != ONNX_TYPE_SEQUENCE) {
    return OrtApis::CreateStatus(ORT_FAIL, "Input is not of type sequence or map.");
  }

  // Note: keep these in sync with the registered types in data_types.h
  if (value->IsTensorSequence()) {
    *out = value->Get<TensorSeq>().Size();
    return nullptr;
  }

#if !defined(DISABLE_ML_OPS)
  utils::ContainerChecker sequence_checker(value->Type());
  if (sequence_checker.IsSequenceOf<std::map<std::string, float>>()) {
    return OrtGetNumSequenceElements<VectorMapStringToFloat>(value, out);
  }
  if (sequence_checker.IsSequenceOf<std::map<int64_t, float>>()) {
    return OrtGetNumSequenceElements<VectorMapInt64ToFloat>(value, out);
  }
  return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported sequence types.");
#else
  return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
// Public entry point for querying the number of sub-values of a map or sequence OrtValue.
// Runs OrtGetValueCountImpl inside the API_IMPL_BEGIN/API_IMPL_END wrapper.
ORT_API_STATUS_IMPL(OrtApis::GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out) {
API_IMPL_BEGIN
return OrtGetValueCountImpl(value, out);
API_IMPL_END
}
namespace c_api_internal {
#if !defined(DISABLE_ML_OPS)
///////////////////
// OrtGetValueImplSeqOfMap
// Copies the map at position `index` of a sequence-of-maps OrtValue (container type T) into a
// newly allocated OrtValue returned via `*out`. The element is accessed with at(), so an invalid
// index results in an exception rather than an out-of-bounds read.
template <typename T>
static ORT_STATUS_PTR OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int index, _Outptr_ OrtValue** out) {
  using TKey = typename T::value_type::key_type;
  using TVal = typename T::value_type::mapped_type;
  using MapType = std::map<TKey, TVal>;
  auto& data_vec = p_ml_value->Get<T>();
  auto& data_elem = data_vec.at(index);
  auto copy_data_elem = std::make_unique<MapType>(data_elem);
  auto value = std::make_unique<OrtValue>();
  auto ml_type = DataTypeImpl::GetType<MapType>();
  // The new OrtValue takes ownership of the copied map together with the matching delete function.
  value->Init(copy_data_elem.release(),
              ml_type,
              ml_type->GetDeleteFunc());
  *out = value.release();
  return nullptr;
}
#endif

// Fills `tensor` from the array `data_elem`.
//
//   tensor     destination; its Shape().Size() determines how many elements are written.
//   is_string  true when the elements are std::string (copied element-wise), false when they are
//              plain data of `elem_size` bytes each (copied with memcpy).
//   data_elem  source array holding at least `num_elems` elements.
//   num_elems  number of elements available in `data_elem`.
//
// Returns ORT_INVALID_ARGUMENT if the source holds fewer elements than the tensor. When the source
// holds more elements than the tensor, only the first Shape().Size() elements are copied so the
// tensor buffer is never overrun.
ORT_STATUS_PTR PopulateTensorWithData(Tensor& tensor, bool is_string, _In_ const void* data_elem, size_t num_elems,
                                      size_t elem_size) {
  auto len = narrow<size_t>(tensor.Shape().Size());
  if (num_elems < len) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input array is too short");
  }
  if (!is_string) {
    // Copy exactly as many elements as the tensor holds; skip memcpy for empty tensors because the
    // pointers involved may legitimately be null in that case.
    if (len != 0) {
      memcpy(tensor.MutableDataRaw(), data_elem, elem_size * len);
    }
  } else {
    const std::string* strings = reinterpret_cast<const std::string*>(data_elem);
    // Bound the source span by the tensor size, not by the (possibly larger) source size.
    auto str_span = gsl::make_span(strings, len);
    auto* dst = tensor.MutableData<std::string>();
    std::copy(str_span.begin(), str_span.end(), dst);
  }
  return nullptr;
}

// Creates a tensor of `element_type` with the given shape in `result` (using `allocator`) and then
// fills it from `data`, which holds `num_elements` elements. Errors from either step are returned.
ORT_STATUS_PTR CreateTensorAndPopulate(MLDataType element_type, const int64_t* shape, size_t shape_len,
                                       const void* data, size_t num_elements, _Inout_ OrtAllocator* allocator, OrtValue& result) {
  ORT_API_RETURN_IF_ERROR(CreateTensorImpl(element_type, shape, shape_len, allocator, result));
  ORT_API_RETURN_IF_ERROR(PopulateTensorWithData(*result.GetMutable<Tensor>(), utils::IsDataTypeString(element_type),
                                                 data, num_elements, element_type->Size()));
  return nullptr;
}
}  // namespace c_api_internal
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 6101)
#endif
// Returns, via `*out`, a newly allocated OrtValue holding a copy of the tensor at position `index`
// of a tensor sequence. The copy's buffer is allocated with `allocator`.
static ORT_STATUS_PTR OrtGetValueImplSeqOfTensors(_In_ const OrtValue* p_ml_value, int index, _Inout_ OrtAllocator* allocator,
                                                  _Outptr_ OrtValue** out) {
  const auto& sequence = p_ml_value->Get<TensorSeq>();
  const auto& source_tensor = sequence.Get(index);
  const auto& source_shape = source_tensor.Shape();

  auto tensor_copy = std::make_unique<OrtValue>();
  ORT_API_RETURN_IF_ERROR(c_api_internal::CreateTensorAndPopulate(source_tensor.DataType(),
                                                                  source_shape.GetDims().data(),
                                                                  source_shape.NumDimensions(),
                                                                  source_tensor.DataRaw(),
                                                                  narrow<size_t>(source_tensor.Shape().Size()),
                                                                  allocator, *tensor_copy));
  *out = tensor_copy.release();
  return nullptr;
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif
// Extracts element `index` of a sequence OrtValue into a new OrtValue returned via `*out`.
// Tensor sequences are always supported; sequences of map<string,float> and map<int64,float> are
// supported only when ML ops are enabled. Anything else yields ORT_FAIL.
static ORT_STATUS_PTR OrtGetValueImplSeq(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator,
                                         _Outptr_ OrtValue** out) {
  // Note: keep these in sync with the registered types in data_types.h
  if (value->IsTensorSequence()) {
    return OrtGetValueImplSeqOfTensors(value, index, allocator, out);
  }

#if !defined(DISABLE_ML_OPS)
  utils::ContainerChecker sequence_checker(value->Type());
  if (sequence_checker.IsSequenceOf<std::map<std::string, float>>()) {
    return c_api_internal::OrtGetValueImplSeqOfMap<VectorMapStringToFloat>(value, index, out);
  }
  if (sequence_checker.IsSequenceOf<std::map<int64_t, float>>()) {
    return c_api_internal::OrtGetValueImplSeqOfMap<VectorMapInt64ToFloat>(value, index, out);
  }
  return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported sequence types.");
#else
  return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
#if !defined(DISABLE_ML_OPS)
// Extracts either the keys (index 0) or the values (index 1) of a map OrtValue of type T into a new
// one-dimensional tensor whose length equals the number of key/value pairs. The tensor is allocated
// with `allocator` and returned via `*out`. Any other index yields ORT_FAIL.
template <typename T>
static ORT_STATUS_PTR OrtGetValueImplMapHelper(_In_ const OrtValue* p_ml_value, int index,
_Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) {
using namespace onnxruntime::utils;
using TKey = typename T::key_type;
using TVal = typename T::mapped_type;
auto& data = p_ml_value->Get<T>();
int64_t num_kv_pairs = data.size();
#if defined(_WIN32) && !defined(_M_AMD64)
// On Windows targets other than AMD64, make sure the pair count is below the size_t maximum.
ORT_ENFORCE(static_cast<uint64_t>(num_kv_pairs) < std::numeric_limits<size_t>::max());
#endif
const std::vector<int64_t> dims{num_kv_pairs};
auto result = std::make_unique<OrtValue>();
// Staging buffers. They are declared at function scope because data_ptr points into one of them
// and must stay valid until CreateTensorAndPopulate has copied the data.
std::vector<TKey> vec_keys;
std::vector<TVal> vec_vals;
const void* data_ptr;
size_t data_size;
MLDataType element_type;
switch (index) {
case 0: { // user is requesting keys
element_type = DataTypeImpl::TensorTypeFromONNXEnum(GetONNXTensorElementDataType<TKey>())->GetElementType();
vec_keys.reserve(static_cast<size_t>(num_kv_pairs));
std::transform(data.cbegin(), data.cend(), std::back_inserter(vec_keys), [](const auto& k) { return k.first; });
data_ptr = vec_keys.data();
data_size = vec_keys.size();
} break;
case 1: { // user is requesting values
element_type = DataTypeImpl::TensorTypeFromONNXEnum(GetONNXTensorElementDataType<TVal>())->GetElementType();
vec_vals.reserve(static_cast<size_t>(num_kv_pairs));
std::transform(data.cbegin(), data.cend(), std::back_inserter(vec_vals), [](const auto& k) { return k.second; });
data_ptr = vec_vals.data();
data_size = vec_vals.size();
} break;
default:
return OrtApis::CreateStatus(ORT_FAIL, "Invalid index requested for map type.");
}
// Copies the staged keys/values into a freshly created tensor inside *result.
ORT_API_RETURN_IF_ERROR(c_api_internal::CreateTensorAndPopulate(element_type, dims.data(), dims.size(), data_ptr,
data_size, allocator, *result));
*out = result.release();
return nullptr;
}
// Dispatches on the concrete key/value types of a map OrtValue and forwards to the matching
// OrtGetValueImplMapHelper instantiation (index 0 = keys, index 1 = values).
// Returns ORT_FAIL when the value is not a map of one of the supported type combinations.
static ORT_STATUS_PTR OrtGetValueImplMap(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator,
                                         _Outptr_ OrtValue** out) {
  const OrtValue* map_value = reinterpret_cast<const OrtValue*>(value);
  // Note: keep these in sync with the registered types in data_types.h
  utils::ContainerChecker map_checker(map_value->Type());
  if (!map_checker.IsMap()) {
    return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types.");
  }

  // String-keyed maps.
  if (map_checker.IsMapOf<std::string, std::string>()) {
    return OrtGetValueImplMapHelper<MapStringToString>(map_value, index, allocator, out);
  }
  if (map_checker.IsMapOf<std::string, int64_t>()) {
    return OrtGetValueImplMapHelper<MapStringToInt64>(map_value, index, allocator, out);
  }
  if (map_checker.IsMapOf<std::string, float>()) {
    return OrtGetValueImplMapHelper<MapStringToFloat>(map_value, index, allocator, out);
  }
  if (map_checker.IsMapOf<std::string, double>()) {
    return OrtGetValueImplMapHelper<MapStringToDouble>(map_value, index, allocator, out);
  }

  // int64-keyed maps.
  if (map_checker.IsMapOf<int64_t, std::string>()) {
    return OrtGetValueImplMapHelper<MapInt64ToString>(map_value, index, allocator, out);
  }
  if (map_checker.IsMapOf<int64_t, int64_t>()) {
    return OrtGetValueImplMapHelper<MapInt64ToInt64>(map_value, index, allocator, out);
  }
  if (map_checker.IsMapOf<int64_t, float>()) {
    return OrtGetValueImplMapHelper<MapInt64ToFloat>(map_value, index, allocator, out);
  }
  if (map_checker.IsMapOf<int64_t, double>()) {
    return OrtGetValueImplMapHelper<MapInt64ToDouble>(map_value, index, allocator, out);
  }

  return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types.");
}
#endif
// Extracts sub-value `index` from a map or sequence OrtValue into a new OrtValue returned via `*out`.
// Maps are handled by OrtGetValueImplMap (ORT_FAIL when ML ops are disabled), sequences by
// OrtGetValueImplSeq; every other value type yields ORT_FAIL.
static ORT_STATUS_PTR OrtGetValueImpl(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator,
                                      _Outptr_ OrtValue** out) {
  ONNXType value_type;
  if (auto type_status = OrtApis::GetValueType(value, &value_type)) {
    return type_status;
  }

  if (value_type == ONNX_TYPE_SEQUENCE) {
    return OrtGetValueImplSeq(value, index, allocator, out);
  }

  if (value_type != ONNX_TYPE_MAP) {
    return OrtApis::CreateStatus(ORT_FAIL, "Input is not of type sequence or map.");
  }

#if !defined(DISABLE_ML_OPS)
  return OrtGetValueImplMap(value, index, allocator, out);
#else
  return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
// Public entry point for extracting sub-value `index` of a map or sequence OrtValue.
// Runs OrtGetValueImpl inside the API_IMPL_BEGIN/API_IMPL_END wrapper.
ORT_API_STATUS_IMPL(OrtApis::GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** out) {
API_IMPL_BEGIN
return OrtGetValueImpl(value, index, allocator, out);
API_IMPL_END
}
///////////////////
// OrtCreateValue
#if !defined(DISABLE_ML_OPS)
// Builds a sequence OrtValue (std::vector<T>) by copying the map of type T out of each of the
// `num_values` input values. The new OrtValue is returned via `*out`.
template <typename T>
static ORT_STATUS_PTR OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size_t num_values,
                                                     _Outptr_ OrtValue** out) {
  using SeqType = std::vector<T>;

  auto sequence = std::make_unique<SeqType>();
  sequence->reserve(num_values);
  for (const OrtValue* const* current = in; current != in + num_values; ++current) {
    const auto& map_element = reinterpret_cast<const OrtValue*>(*current)->Get<T>();
    sequence->push_back(map_element);
  }

  // Wrap the vector in an OrtValue that owns it through the type's delete function.
  auto wrapped = std::make_unique<OrtValue>();
  auto sequence_type = DataTypeImpl::GetType<SeqType>();
  wrapped->Init(sequence.release(),
                sequence_type,
                sequence_type->GetDeleteFunc());
  *out = wrapped.release();
  return nullptr;
}
#endif
// Builds a tensor-sequence OrtValue from `num_values` tensor inputs and returns it via `*out`.
// The element data type is taken from the first input (so `in[0]` is read unconditionally); every
// input must be a tensor of that same data type, otherwise ORT_FAIL is returned.
static ORT_STATUS_PTR OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t num_values,
                                                  _Outptr_ OrtValue** out) {
  using namespace c_api_internal;
  auto expected_type = in[0]->Get<Tensor>().DataType();
  auto sequence = std::make_unique<TensorSeq>(expected_type);
  sequence->Reserve(num_values);

  for (size_t position = 0; position != num_values; ++position) {
    ORT_ENFORCE(in[position]->IsTensor(), "Expecting all elements to be tensors. Got: ", DataTypeImpl::ToString(in[position]->Type()));

    // sequences must have tensors of the same data type
    auto current_type = in[position]->Get<Tensor>().DataType();
    if (current_type != expected_type) {
      return OrtApis::CreateStatus(ORT_FAIL,
                                   "Sequences must have tensors of the same data type. There was at least one tensor in the input that was different.");
    }

    sequence->Add(*in[position]);
  }

  // Wrap the sequence in an OrtValue that owns it through the type's delete function.
  auto wrapped = std::make_unique<OrtValue>();
  auto sequence_type = DataTypeImpl::GetType<TensorSeq>();
  wrapped->Init(sequence.release(),
                sequence_type,
                sequence_type->GetDeleteFunc());
  *out = wrapped.release();
  return nullptr;
}
// Creates a sequence OrtValue from `num_values` inputs.
// Only homogeneous sequences are supported: the ONNX type of in[0] (tensor or map) decides the
// container that is built, and every other element must report the same ONNX type.
// Returns an error status when an element type is unsupported or the elements disagree.
static ORT_STATUS_PTR OrtCreateValueImplSeq(_In_reads_(num_values) const OrtValue* const* in, size_t num_values,
                                            _Outptr_ OrtValue** out) {
  const OrtValue* const first = in[0];

  ONNXType seq_elem_type;
  if (auto status = OrtApis::GetValueType(first, &seq_elem_type)) {
    return status;
  }

  // Only sequences of tensors or of maps have registered container types.
  if (seq_elem_type != ONNX_TYPE_TENSOR && seq_elem_type != ONNX_TYPE_MAP) {
    return OrtApis::CreateStatus(ORT_FAIL, "Each element of the sequence should be either tensor or map.");
  }

  // The ONNX spec allows heterogeneous sequences, but only fixed element types are registered
  // here, so every element has to match the first one.
  for (size_t i = 0; i < num_values; ++i) {
    ONNXType current_type;
    if (auto status = OrtApis::GetValueType(in[i], &current_type)) {
      return status;
    }
    if (current_type != seq_elem_type) {
      return OrtApis::CreateStatus(ORT_FAIL,
                                   "At least one element in the sequence is of a type different from others.");
    }
  }

  if (seq_elem_type == ONNX_TYPE_TENSOR) {
    return OrtCreateValueImplSeqHelper(in, num_values, out);
  }
  if (seq_elem_type != ONNX_TYPE_MAP) {
    return OrtApis::CreateStatus(ORT_FAIL, "Unsupported input type");
  }

#if !defined(DISABLE_ML_OPS)
  // Pick the concrete map type from the first element.
  utils::ContainerChecker map_checker(first->Type());
  if (map_checker.IsMapOf<std::string, float>()) {
    return OrtCreateValueImplSeqHelperMap<MapStringToFloat>(in, num_values, out);
  }
  if (map_checker.IsMapOf<int64_t, float>()) {
    return OrtCreateValueImplSeqHelperMap<MapInt64ToFloat>(in, num_values, out);
  }
  return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types.");
#else
  return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
}
#if !defined(DISABLE_ML_OPS)
// Builds an OrtValue holding a std::map<KeyType, ValueType> from two parallel tensors.
// Element i of `key_tensor` is paired with element i of `value_tensor`; the number of pairs is the
// element count of `key_tensor`. When a key repeats, the first occurrence is kept.
//   out - receives the newly allocated OrtValue; ownership passes to the caller.
// Always returns nullptr (success).
template <typename KeyType, typename ValueType>
static OrtStatus* OrtCreateMapMLValue(const Tensor& key_tensor, const Tensor& value_tensor, _Outptr_ OrtValue** out) {
  using MapType = std::map<KeyType, ValueType>;
  auto entries = std::make_unique<MapType>();

  auto keys = key_tensor.Data<KeyType>();
  auto values = value_tensor.Data<ValueType>();

  // The element count must be non-negative and representable as size_t.
  auto len = key_tensor.Shape().Size();
  ORT_ENFORCE(len >= 0 && static_cast<uint64_t>(len) < std::numeric_limits<size_t>::max());
  const size_t pair_count = static_cast<size_t>(len);

  for (size_t i = 0; i < pair_count; ++i) {
    entries->emplace(keys[i], values[i]);
  }

  // Wrap the map in an OrtValue that deletes it through the registered type's deleter.
  auto result = std::make_unique<OrtValue>();
  auto map_ml_type = DataTypeImpl::GetType<MapType>();
  result->Init(entries.release(), map_ml_type, map_ml_type->GetDeleteFunc());

  *out = result.release();
  return nullptr;
}
// Dispatches on the element type of `value_tensor` to the OrtCreateMapMLValue instantiation for
// KeyType. Supported value element types: string, int64, float and double; anything else yields
// an ORT_FAIL status naming the offending type.
template <typename KeyType>
static ORT_STATUS_PTR OrtCreateValueImplMapHelper(const Tensor& key_tensor, const Tensor& value_tensor,
                                                  _Outptr_ OrtValue** out) {
  auto value_type = value_tensor.DataType()->AsPrimitiveDataType();
  ORT_ENFORCE(value_type != nullptr, "Tensor must always contain primitive types. Found: ",
              DataTypeImpl::ToString(value_tensor.DataType()));

  switch (value_type->GetDataType()) {
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:
      return OrtCreateMapMLValue<KeyType, std::string>(key_tensor, value_tensor, out);
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
      return OrtCreateMapMLValue<KeyType, int64_t>(key_tensor, value_tensor, out);
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
      return OrtCreateMapMLValue<KeyType, float>(key_tensor, value_tensor, out);
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
      return OrtCreateMapMLValue<KeyType, double>(key_tensor, value_tensor, out);
    default:
      break;  // unsupported: reported below
  }

  std::string message("Value type is not supported yet: ");
  message.append(DataTypeImpl::ToString(value_tensor.DataType()));
  return OrtApis::CreateStatus(ORT_FAIL, message.c_str());
}
// Creates a map OrtValue from exactly two inputs: in[0] is the key tensor and in[1] the value
// tensor. Both tensors may have at most one dimension and must hold the same number of elements.
// Keys may be strings or int64; any other key type yields an ORT_FAIL status.
static ORT_STATUS_PTR OrtCreateValueImplMap(const OrtValue* const* in, size_t num_values, _Outptr_ OrtValue** out) {
  if (num_values != NUM_MAP_INDICES) {
    return OrtApis::CreateStatus(ORT_FAIL, "For map type num_values MUST be 2");
  }

  const Tensor& key_tensor = in[0]->Get<Tensor>();
  const Tensor& value_tensor = in[1]->Get<Tensor>();

  // As per data_types.h, only maps of primitive data types are supported.
  const bool key_rank_ok = key_tensor.Shape().NumDimensions() <= 1;
  if (!key_rank_ok || value_tensor.Shape().NumDimensions() > 1) {
    return OrtApis::CreateStatus(ORT_FAIL, "Either the key tensor or the value tensor has NumDimensions > 1");
  }

  // Keys and values are paired element by element, so the counts have to agree.
  if (key_tensor.Shape().Size() != value_tensor.Shape().Size()) {
    return OrtApis::CreateStatus(ORT_FAIL, "Key and value tensors have unequal number of elements.");
  }

  if (key_tensor.IsDataTypeString()) {
    return OrtCreateValueImplMapHelper<std::string>(key_tensor, value_tensor, out);
  }
  if (key_tensor.IsDataType<int64_t>()) {
    return OrtCreateValueImplMapHelper<int64_t>(key_tensor, value_tensor, out);
  }
  return OrtApis::CreateStatus(ORT_FAIL, "Key type is not supported yet.");
}
#endif
// Creates a map or sequence OrtValue from `num_values` inputs, depending on `value_type`.
// Fails when `num_values` is zero, when `value_type` is neither ONNX_TYPE_MAP nor
// ONNX_TYPE_SEQUENCE, or when a map is requested in a build with DISABLE_ML_OPS defined.
static ORT_STATUS_PTR OrtCreateValueImpl(_In_reads_(num_values) const OrtValue* const* in, size_t num_values,
                                         enum ONNXType value_type, _Outptr_ OrtValue** out) {
  // num_values is unsigned, so "at least 1" means "not zero".
  if (num_values == 0) {
    return OrtApis::CreateStatus(ORT_FAIL, "Number of values should be at least 1.");
  }

  switch (value_type) {
    case ONNX_TYPE_MAP:
#if !defined(DISABLE_ML_OPS)
      return OrtCreateValueImplMap(in, num_values, out);
#else
      return OrtApis::CreateStatus(ORT_FAIL, "Map type is not supported in this build.");
#endif
    case ONNX_TYPE_SEQUENCE:
      return OrtCreateValueImplSeq(in, num_values, out);
    default:
      return OrtApis::CreateStatus(ORT_FAIL, "Input is not of type sequence or map.");
  }
}
// C API entry point: forwards `in`, `num_values`, `value_type` and `out` unchanged to
// OrtCreateValueImpl and returns its status.
// NOTE(review): API_IMPL_BEGIN/API_IMPL_END are defined elsewhere; presumably they convert
// exceptions thrown by the implementation into an OrtStatus - confirm.
ORT_API_STATUS_IMPL(OrtApis::CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values,
enum ONNXType value_type, _Outptr_ OrtValue** out) {
API_IMPL_BEGIN
return OrtCreateValueImpl(in, num_values, value_type, out);
API_IMPL_END
}
// Creates an OrtValue of a registered opaque type from a caller-supplied data container.
// The type is looked up by the string "opaque(<domain_name>,<type_name>)"; the lookup and the
// non-tensor check are enforced with ORT_ENFORCE.
//   data_container / data_container_size - passed through to the type's FromDataContainer().
//   out - receives the newly allocated OrtValue; ownership passes to the caller.
ORT_API_STATUS_IMPL(OrtApis::CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name,
                    _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out) {
  API_IMPL_BEGIN
  // Compose the registry key for the opaque type.
  std::string dtype = "opaque(";
  dtype += domain_name;
  dtype += ",";
  dtype += type_name;
  dtype += ")";

  MLDataType ml_type = DataTypeImpl::GetDataType(dtype);
  ORT_ENFORCE(ml_type != nullptr,
              "Specified domain and type names combination does not refer to a registered opaque type");
  const auto* non_tensor_base = ml_type->AsNonTensorType();
  ORT_ENFORCE(non_tensor_base != nullptr, "Opaque type is not a non_tensor type!!!");

  auto ort_val = std::make_unique<OrtValue>();
  non_tensor_base->FromDataContainer(data_container, data_container_size, *ort_val);
  *out = ort_val.release();
  API_IMPL_END
  return nullptr;
}
// Copies the content of an opaque-typed OrtValue into a caller-supplied data container.
// The type is looked up by the string "opaque(<domain_name>,<type_name>)"; the lookup and the
// non-tensor check are enforced with ORT_ENFORCE.
//   in - the OrtValue to read from.
//   data_container / data_container_size - passed through to the type's ToDataContainer().
ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name,
                    _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size) {
  API_IMPL_BEGIN
  // Compose the registry key for the opaque type.
  std::string dtype = "opaque(";
  dtype += domain_name;
  dtype += ",";
  dtype += type_name;
  dtype += ")";

  MLDataType ml_type = DataTypeImpl::GetDataType(dtype);
  ORT_ENFORCE(ml_type != nullptr,
              "Specified domain and type names combination does not refer to a registered opaque type");
  const auto* non_tensor_base = ml_type->AsNonTensorType();
  ORT_ENFORCE(non_tensor_base != nullptr, "Opaque type is not a non_tensor type!!!");

  non_tensor_base->ToDataContainer(*in, data_container_size, data_container);
  API_IMPL_END
  return nullptr;
}
// Returns the names of the execution providers available in this build.
//   out_ptr          - receives a `new[]`-allocated array of `*providers_length` pointers, each
//                      pointing at a `new[]`-allocated, NUL-terminated copy of a provider name.
//   providers_length - receives the number of entries in the array.
// The result must be released with ReleaseAvailableProviders, which delete[]s every entry and
// then the array itself.
//
// Each name is copied into a buffer sized for that name, so nothing is truncated and the copy is
// the same on every platform. All buffers stay owned by RAII objects until every allocation has
// succeeded, so nothing leaks if an allocation throws part way through.
GSL_SUPPRESS(r .11)
ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
                    _In_ int* providers_length) {
  API_IMPL_BEGIN
  const auto& available_providers = GetAvailableExecutionProviderNames();
  const int available_count = narrow<int>(available_providers.size());

  // Stage 1: copy every name while the copies are still owned by unique_ptrs.
  std::vector<std::unique_ptr<char[]>> name_copies;
  name_copies.reserve(available_providers.size());
  for (const auto& provider_name : available_providers) {
    const size_t bytes = provider_name.size() + 1;  // include the terminating NUL
    auto buffer = std::make_unique<char[]>(bytes);
    memcpy(buffer.get(), provider_name.c_str(), bytes);
    name_copies.push_back(std::move(buffer));
  }

  // Stage 2: allocate the pointer array, then hand ownership over. Nothing below can throw.
  auto out = std::make_unique<char*[]>(name_copies.size());
  for (size_t i = 0; i < name_copies.size(); ++i) {
    out[i] = name_copies[i].release();
  }

  *providers_length = available_count;
  *out_ptr = out.release();
  API_IMPL_END
  return nullptr;
}
// TODO: we don't really need the second parameter
// Frees an array obtained from GetAvailableProviders: delete[]s the first `providers_length`
// entries and then the array itself. A null `ptr` is accepted and ignored.
ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
                    _In_ int providers_length) {
  API_IMPL_BEGIN
  if (ptr != nullptr) {
    int remaining = providers_length;
    for (char** entry = ptr; remaining > 0; ++entry, --remaining) {
      GSL_SUPPRESS(r .11)
      delete[] *entry;
    }
    GSL_SUPPRESS(r .11)
    delete[] ptr;
  }
  API_IMPL_END
  return nullptr;
}
// Looks up the provider-specific API table for `provider_name` at the requested `version`.
// Only "DML" is recognised, and only when built with USE_DML; every other name (or any name in a
// build without USE_DML) yields ORT_INVALID_ARGUMENT. `*provider_api` is nullptr on failure.
ORT_API_STATUS_IMPL(OrtApis::GetExecutionProviderApi,
                    [[maybe_unused]] _In_ const char* provider_name,
                    [[maybe_unused]] _In_ uint32_t version,
                    _Outptr_ const void** provider_api) {
  API_IMPL_BEGIN
  *provider_api = nullptr;

#ifdef USE_DML
  if (strcmp(provider_name, "DML") == 0) {
    *provider_api = GetOrtDmlApi(version);
    if (*provider_api != nullptr) {
      return nullptr;
    }
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified version is not supported for the DirectML provider.");
  }
#endif

  return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified provider is not supported.");
  API_IMPL_END
}
// Returns, in `*out`, a pointer to the element of a non-string tensor addressed by a
// multi-dimensional index.
//   location_values       - one index per tensor dimension, outermost dimension first.
//   location_values_count - must equal the tensor's number of dimensions.
// Returns ORT_INVALID_ARGUMENT for string tensors, for a wrong index count, or when an index is
// negative or not smaller than its dimension.
// NOTE(review): `tensor` is not declared in this function; presumably TENSOR_READWRITE_API_BEGIN
// (defined elsewhere) derives it from `value` - confirm.
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count,
                    _Outptr_ void** out) {
  TENSOR_READWRITE_API_BEGIN
  if (tensor->IsDataTypeString()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "this API does not support strings");
  }

  const auto& tensor_shape = tensor->Shape();
  const auto num_dimensions = tensor_shape.NumDimensions();
  if (location_values_count != num_dimensions) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "location dimensions do not match shape size");
  }

  // Validate each index and fold it into the row-major element offset in the same pass:
  //   offset = (...((loc[0] * dim[1]) + loc[1]) * dim[2] + ...) + loc[n-1]
  // This equals sum(loc[i] * stride[i]) with stride[i] = product of dim[i+1..n-1], without
  // materialising a strides array on every call. For scalars the offset stays zero.
  int64_t offset = 0;
  for (size_t i = 0; i < num_dimensions; i++) {
    const int64_t dim = tensor_shape[i];
    const int64_t index = location_values[i];
    if (index >= dim || index < 0) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "invalid location range");
    }
    offset = offset * dim + index;
  }

  auto data = reinterpret_cast<char*>(tensor->MutableDataRaw()) + tensor->DataType()->Size() * offset;
  *out = data;
  return nullptr;
  API_IMPL_END
}
// Records which language binding is in use with the telemetry provider.
// `ort_env` is ignored: telemetry is controlled via the platform Env object, not the OrtEnv
// instance.
ORT_API_STATUS_IMPL(OrtApis::SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection) {
  API_IMPL_BEGIN
  ORT_UNUSED_PARAMETER(ort_env);

  const auto projection_id = static_cast<uint32_t>(projection);
  const Env& platform_env = Env::Default();
  platform_env.GetTelemetryProvider().SetLanguageProjection(projection_id);

  return nullptr;
  API_IMPL_END
}
// Writes the session profiler's start time, as reported by GetStartTimeNs(), to `*out`.
ORT_API_STATUS_IMPL(OrtApis::SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Out_ uint64_t* out) {
  API_IMPL_BEGIN
  // OrtSession is the opaque handle for an InferenceSession.
  const auto& inference_session = *reinterpret_cast<const ::onnxruntime::InferenceSession*>(sess);
  *out = static_cast<uint64_t>(inference_session.GetProfiling().GetStartTimeNs());
  return nullptr;
  API_IMPL_END
}
// End support for non-tensor types
// Allocates an OrtArenaCfg and fills in the four settings given as arguments; every other field
// keeps the value it gets from value-initialisation.
//   out - receives the new configuration; release it with ReleaseArenaCfg.
ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
                    int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out) {
  API_IMPL_BEGIN
  auto arena_cfg = std::make_unique<OrtArenaCfg>();
  arena_cfg->max_mem = max_mem;
  arena_cfg->arena_extend_strategy = arena_extend_strategy;
  arena_cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
  arena_cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;

  *out = arena_cfg.release();
  return nullptr;
  API_IMPL_END
}
// Allocates an OrtArenaCfg from parallel arrays of keys and values.
// Recognised keys: "max_mem", "arena_extend_strategy", "initial_chunk_size_bytes",
// "max_dead_bytes_per_chunk" and "initial_growth_chunk_size_bytes". All values except "max_mem"
// are narrowed to int. An unrecognised key yields ORT_INVALID_ARGUMENT and nothing is returned.
//   out - receives the new configuration on success; release it with ReleaseArenaCfg.
ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfgV2, _In_reads_(num_keys) const char* const* arena_config_keys, _In_reads_(num_keys) const size_t* arena_config_values,
                    _In_ size_t num_keys, _Outptr_ OrtArenaCfg** out) {
  API_IMPL_BEGIN
  auto cfg = std::make_unique<OrtArenaCfg>();

  for (size_t i = 0; i < num_keys; ++i) {
    const char* const key = arena_config_keys[i];
    const auto key_is = [key](const char* candidate) { return strcmp(key, candidate) == 0; };

    if (key_is("max_mem")) {
      cfg->max_mem = arena_config_values[i];
    } else if (key_is("arena_extend_strategy")) {
      cfg->arena_extend_strategy = static_cast<int>(arena_config_values[i]);
    } else if (key_is("initial_chunk_size_bytes")) {
      cfg->initial_chunk_size_bytes = static_cast<int>(arena_config_values[i]);
    } else if (key_is("max_dead_bytes_per_chunk")) {
      cfg->max_dead_bytes_per_chunk = static_cast<int>(arena_config_values[i]);
    } else if (key_is("initial_growth_chunk_size_bytes")) {
      cfg->initial_growth_chunk_size_bytes = static_cast<int>(arena_config_values[i]);
    } else {
      const std::string message = std::string("Invalid key found: ") + key;
      return CreateStatus(ORT_INVALID_ARGUMENT, message.c_str());
    }
  }

  *out = cfg.release();
  return nullptr;
  API_IMPL_END
}
// Destroys an OrtArenaCfg. `delete` on a null pointer is a no-op, so a null `ptr` is accepted.
// Allow using raw new/delete because this is for C.
GSL_SUPPRESS(r .11)
ORT_API(void, OrtApis::ReleaseArenaCfg, _Frees_ptr_opt_ OrtArenaCfg* ptr) {
delete ptr;
}
// Allocates a PrepackedWeightsContainer and returns it through `out` as the opaque
// OrtPrepackedWeightsContainer handle. Ownership passes to the caller.
ORT_API_STATUS_IMPL(OrtApis::CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out) {
API_IMPL_BEGIN
std::unique_ptr<PrepackedWeightsContainer> container = std::make_unique<PrepackedWeightsContainer>();
// The C handle type is just a reinterpretation of the C++ object pointer.
*out = reinterpret_cast<OrtPrepackedWeightsContainer*>(container.release());
return nullptr;
API_IMPL_END
}
// Destroys a container handle: casts the opaque handle back to PrepackedWeightsContainer and
// deletes it. `delete` on a null pointer is a no-op, so a null `ptr` is accepted.
ORT_API(void, OrtApis::ReleasePrepackedWeightsContainer, _Frees_ptr_opt_ OrtPrepackedWeightsContainer* ptr) {
delete reinterpret_cast<PrepackedWeightsContainer*>(ptr);
}
// Creates a session from the model at `model_path` and initialises it with the supplied
// prepacked-weights container.
//   out - set to nullptr up front and only overwritten with the new session handle once both
//         the load and the initialisation steps have succeeded.
// A std::exception raised inside the ORT_TRY block is reported as an ORT_FAIL status carrying
// the exception's what() text.
ORT_API_STATUS_IMPL(OrtApis::CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container,
_Outptr_ OrtSession** out) {
API_IMPL_BEGIN
std::unique_ptr<onnxruntime::InferenceSession> sess;
OrtStatus* status = nullptr;
*out = nullptr;
ORT_TRY {
// Load from the file path; the in-memory model arguments are passed as nullptr / 0.
ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container));
// Ownership of the session moves to the caller through the opaque handle.
*out = reinterpret_cast<OrtSession*>(sess.release());
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
status = OrtApis::CreateStatus(ORT_FAIL, e.what());
});
}
return status;
API_IMPL_END
}
// Creates a session from an in-memory model (`model_data`, `model_data_length`) and initialises
// it with the supplied prepacked-weights container.
//   out - set to nullptr up front and only overwritten with the new session handle once both
//         the load and the initialisation steps have succeeded.
// A std::exception raised inside the ORT_TRY block is reported as an ORT_FAIL status carrying
// the exception's what() text.
ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env,
_In_ const void* model_data, size_t model_data_length,
_In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container,
_Outptr_ OrtSession** out) {
API_IMPL_BEGIN
std::unique_ptr<onnxruntime::InferenceSession> sess;
OrtStatus* status = nullptr;
*out = nullptr;
ORT_TRY {
// Load from the memory buffer; the model path argument is passed as nullptr.
ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data,
model_data_length, sess));
ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container));
// Ownership of the session moves to the caller through the opaque handle.
*out = reinterpret_cast<OrtSession*>(sess.release());
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
status = OrtApis::CreateStatus(ORT_FAIL, e.what());
});
}
return status;
API_IMPL_END
}
// Returns, through `memory_info`, the address of the tensor's own Location() object; nothing is
// copied or allocated here.
// NOTE(review): `tensor` is not declared in this function; presumably TENSOR_READ_API_BEGIN
// (defined elsewhere) derives it from `value` - confirm.
ORT_API_STATUS_IMPL(OrtApis::GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info) {
TENSOR_READ_API_BEGIN
*memory_info = &tensor.Location();
return nullptr;
API_IMPL_END
}
// Stores the caller's thread-creation callback in the session options
// (options->value.custom_create_thread_fn). Always succeeds.
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
API_IMPL_BEGIN
options->value.custom_create_thread_fn = ort_custom_create_thread_fn;
return nullptr;
API_IMPL_END
}
// Stores the caller's opaque thread-creation options pointer in the session options
// (options->value.custom_thread_creation_options). The pointer is stored as-is. Always succeeds.
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options) {
API_IMPL_BEGIN
options->value.custom_thread_creation_options = ort_custom_thread_creation_options;
return nullptr;
API_IMPL_END
}
// Stores the caller's thread-join callback in the session options
// (options->value.custom_join_thread_fn). Always succeeds.
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
API_IMPL_BEGIN
options->value.custom_join_thread_fn = ort_custom_join_thread_fn;
return nullptr;
API_IMPL_END
}
// Returns the training API table for `version`.
// When built with ENABLE_TRAINING_APIS the request is forwarded to OrtTrainingApis::GetTrainingApi.
// Otherwise a message is written to stderr and nullptr is returned.
ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) {
#ifdef ENABLE_TRAINING_APIS
return OrtTrainingApis::GetTrainingApi(version);
#else
ORT_UNUSED_PARAMETER(version);
fprintf(stderr,
"Training APIs are not supported with this build. Please build onnxruntime "
"from source with the build flags enable_training_apis to retrieve the training APIs.\n");
return nullptr;
#endif
}
// The OrtApiBase instance handed out by OrtGetApiBase(): it exposes GetApi (versioned API table
// lookup) and GetVersionString.
static constexpr OrtApiBase ort_api_base = {
&OrtApis::GetApi,
&OrtApis::GetVersionString,
};
/* Rules on how to add a new Ort API version
In general, NEVER remove or rearrange the members in this structure unless a new version is being created. The
goal is for newer shared libraries of the Onnx Runtime to work with binaries targeting the previous versions.
In order to do that we need to ensure older binaries get the older interfaces they are expecting.
If the next version of the OrtApi only adds members, new members can be added at the end of the OrtApi structure
without breaking anything. In this case, rename the ort_api_# structure in a way that shows the range of versions
it supports, for example 'ort_api_1_to_2', and then GetApi can return the same structure for a range of versions.
If methods need to be removed or rearranged, then make a copy of the OrtApi structure and name it 'OrtApi#to#'.
The latest Api should always be named just OrtApi. Then make a copy of the latest ort_api_* structure below and
name it ort_api_# to match the latest version number supported, you'll need to be sure the structure types match
the API they're for (the compiler should complain if this isn't correct).
If there is no desire to have the headers still expose the older APIs (clutter, documentation, etc) then the
definition should be moved to a file included by this file so that it's still defined here for binary compatibility
but isn't visible in public headers.
So for example, if we wanted to just add some new members to the ort_api_1_to_2, we'd take the following steps:
In include\onnxruntime\core\session\onnxruntime_c_api.h we'd just add the members to the end of the structure
In this file, we'd correspondingly add the member values to the end of the ort_api_1_to_2 structure, and also rename
it to ort_api_1_to_3.
Then in GetApi we'd make it return ort_api_1_to_3 for versions 1 through 3.
Second example, if we wanted to add and remove some members, we'd do this:
In include\onnxruntime\core\session\onnxruntime_c_api.h we'd make a copy of the OrtApi structure and name the
old one OrtApi1to2. In the new OrtApi we'd add or remove any members that we desire.
In this file, we'd create a new copy of ort_api_1_to_2 called ort_api_3 and make the corresponding changes that were
made to the new OrtApi.
In GetApi we now make it return ort_api_3 for version 3.
*/
// Function-pointer table returned by OrtApis::GetApi for every supported API version.
// Entries are grouped by the API version that introduced them; the static_asserts that follow
// this table pin the slot index of the last entry of each shipped version.
static constexpr OrtApi ort_api_1_to_14 = {
// NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.
// Shipped as version 1 - DO NOT MODIFY (see above text for more information)
&OrtApis::CreateStatus,
&OrtApis::GetErrorCode,
&OrtApis::GetErrorMessage,
&OrtApis::CreateEnv,
&OrtApis::CreateEnvWithCustomLogger,
&OrtApis::EnableTelemetryEvents,
&OrtApis::DisableTelemetryEvents,
&OrtApis::CreateSession,
&OrtApis::CreateSessionFromArray,
&OrtApis::Run,
&OrtApis::CreateSessionOptions,
&OrtApis::SetOptimizedModelFilePath,
&OrtApis::CloneSessionOptions,
&OrtApis::SetSessionExecutionMode,
&OrtApis::EnableProfiling,
&OrtApis::DisableProfiling,
&OrtApis::EnableMemPattern,
&OrtApis::DisableMemPattern,
&OrtApis::EnableCpuMemArena,
&OrtApis::DisableCpuMemArena,
&OrtApis::SetSessionLogId,
&OrtApis::SetSessionLogVerbosityLevel,
&OrtApis::SetSessionLogSeverityLevel,
&OrtApis::SetSessionGraphOptimizationLevel,
&OrtApis::SetIntraOpNumThreads,
&OrtApis::SetInterOpNumThreads,
&OrtApis::CreateCustomOpDomain,
&OrtApis::CustomOpDomain_Add,
&OrtApis::AddCustomOpDomain,
&OrtApis::RegisterCustomOpsLibrary,
&OrtApis::SessionGetInputCount,
&OrtApis::SessionGetOutputCount,
&OrtApis::SessionGetOverridableInitializerCount,
&OrtApis::SessionGetInputTypeInfo,
&OrtApis::SessionGetOutputTypeInfo,
&OrtApis::SessionGetOverridableInitializerTypeInfo,
&OrtApis::SessionGetInputName,
&OrtApis::SessionGetOutputName,
&OrtApis::SessionGetOverridableInitializerName,
&OrtApis::CreateRunOptions,
&OrtApis::RunOptionsSetRunLogVerbosityLevel,
&OrtApis::RunOptionsSetRunLogSeverityLevel,
&OrtApis::RunOptionsSetRunTag,
&OrtApis::RunOptionsGetRunLogVerbosityLevel,
&OrtApis::RunOptionsGetRunLogSeverityLevel,
&OrtApis::RunOptionsGetRunTag,
&OrtApis::RunOptionsSetTerminate,
&OrtApis::RunOptionsUnsetTerminate,
&OrtApis::CreateTensorAsOrtValue,
&OrtApis::CreateTensorWithDataAsOrtValue,
&OrtApis::IsTensor,
&OrtApis::GetTensorMutableData,
&OrtApis::FillStringTensor,
&OrtApis::GetStringTensorDataLength,
&OrtApis::GetStringTensorContent,
&OrtApis::CastTypeInfoToTensorInfo,
&OrtApis::GetOnnxTypeFromTypeInfo,
&OrtApis::CreateTensorTypeAndShapeInfo,
&OrtApis::SetTensorElementType,
&OrtApis::SetDimensions,
&OrtApis::GetTensorElementType,
&OrtApis::GetDimensionsCount,
&OrtApis::GetDimensions,
&OrtApis::GetSymbolicDimensions,
&OrtApis::GetTensorShapeElementCount,
&OrtApis::GetTensorTypeAndShape,
&OrtApis::GetTypeInfo,
&OrtApis::GetValueType,
&OrtApis::CreateMemoryInfo,
&OrtApis::CreateCpuMemoryInfo,
&OrtApis::CompareMemoryInfo,
&OrtApis::MemoryInfoGetName,
&OrtApis::MemoryInfoGetId,
&OrtApis::MemoryInfoGetMemType,
&OrtApis::MemoryInfoGetType,
&OrtApis::AllocatorAlloc,
&OrtApis::AllocatorFree,
&OrtApis::AllocatorGetInfo,
&OrtApis::GetAllocatorWithDefaultOptions,
&OrtApis::AddFreeDimensionOverride,
&OrtApis::GetValue,
&OrtApis::GetValueCount,
&OrtApis::CreateValue,
&OrtApis::CreateOpaqueValue,
&OrtApis::GetOpaqueValue,
&OrtApis::KernelInfoGetAttribute_float,
&OrtApis::KernelInfoGetAttribute_int64,
&OrtApis::KernelInfoGetAttribute_string,
&OrtApis::KernelContext_GetInputCount,
&OrtApis::KernelContext_GetOutputCount,
&OrtApis::KernelContext_GetInput,
&OrtApis::KernelContext_GetOutput,
&OrtApis::ReleaseEnv,
&OrtApis::ReleaseStatus,
&OrtApis::ReleaseMemoryInfo,
&OrtApis::ReleaseSession,
&OrtApis::ReleaseValue,
&OrtApis::ReleaseRunOptions,
&OrtApis::ReleaseTypeInfo,
&OrtApis::ReleaseTensorTypeAndShapeInfo,
&OrtApis::ReleaseSessionOptions,
&OrtApis::ReleaseCustomOpDomain,
// End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::GetDenotationFromTypeInfo,
&OrtApis::CastTypeInfoToMapTypeInfo,
&OrtApis::CastTypeInfoToSequenceTypeInfo,
&OrtApis::GetMapKeyType,
&OrtApis::GetMapValueType,
&OrtApis::GetSequenceElementType,
&OrtApis::ReleaseMapTypeInfo,
&OrtApis::ReleaseSequenceTypeInfo,
&OrtApis::SessionEndProfiling,
&OrtApis::SessionGetModelMetadata,
&OrtApis::ModelMetadataGetProducerName,
&OrtApis::ModelMetadataGetGraphName,
&OrtApis::ModelMetadataGetDomain,
&OrtApis::ModelMetadataGetDescription,
&OrtApis::ModelMetadataLookupCustomMetadataMap,
&OrtApis::ModelMetadataGetVersion,
&OrtApis::ReleaseModelMetadata,
// End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::CreateEnvWithGlobalThreadPools,
&OrtApis::DisablePerSessionThreads,
&OrtApis::CreateThreadingOptions,
&OrtApis::ReleaseThreadingOptions,
&OrtApis::ModelMetadataGetCustomMetadataMapKeys,
&OrtApis::AddFreeDimensionOverrideByName,
// End of Version 3 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::GetAvailableProviders,
&OrtApis::ReleaseAvailableProviders,
// End of Version 4 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::GetStringTensorElementLength,
&OrtApis::GetStringTensorElement,
&OrtApis::FillStringTensorElement,
&OrtApis::AddSessionConfigEntry,
// IoBinding and above are propagated in the same order to C# API
// Do not move
&OrtApis::CreateAllocator,
&OrtApis::ReleaseAllocator,
&OrtApis::RunWithBinding,
&OrtApis::CreateIoBinding,
&OrtApis::ReleaseIoBinding,
&OrtApis::BindInput,
&OrtApis::BindOutput,
&OrtApis::BindOutputToDevice,
&OrtApis::GetBoundOutputNames,
&OrtApis::GetBoundOutputValues,
&OrtApis::ClearBoundInputs,
&OrtApis::ClearBoundOutputs,
&OrtApis::TensorAt,
&OrtApis::CreateAndRegisterAllocator,
&OrtApis::SetLanguageProjection,
&OrtApis::SessionGetProfilingStartTimeNs,
&OrtApis::SetGlobalIntraOpNumThreads,
&OrtApis::SetGlobalInterOpNumThreads,
&OrtApis::SetGlobalSpinControl,
// End of Version 5 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::AddInitializer,
&OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools,
&OrtApis::SessionOptionsAppendExecutionProvider_CUDA,
&OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
&OrtApis::SetGlobalDenormalAsZero,
&OrtApis::CreateArenaCfg,
&OrtApis::ReleaseArenaCfg,
// End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::ModelMetadataGetGraphDescription,
&OrtApis::SessionOptionsAppendExecutionProvider_TensorRT,
&OrtApis::SetCurrentGpuDeviceId,
&OrtApis::GetCurrentGpuDeviceId,
// End of Version 7 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::KernelInfoGetAttributeArray_float,
&OrtApis::KernelInfoGetAttributeArray_int64,
&OrtApis::CreateArenaCfgV2,
&OrtApis::AddRunConfigEntry,
&OrtApis::CreatePrepackedWeightsContainer,
&OrtApis::ReleasePrepackedWeightsContainer,
&OrtApis::CreateSessionWithPrepackedWeightsContainer,
&OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer,
// End of Version 8 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
&OrtApis::CreateTensorRTProviderOptions,
&OrtApis::UpdateTensorRTProviderOptions,
&OrtApis::GetTensorRTProviderOptionsAsString,
&OrtApis::ReleaseTensorRTProviderOptions,
&OrtApis::EnableOrtCustomOps,
&OrtApis::RegisterAllocator,
&OrtApis::UnregisterAllocator,
&OrtApis::IsSparseTensor,
&OrtApis::CreateSparseTensorAsOrtValue,
&OrtApis::FillSparseTensorCoo,
&OrtApis::FillSparseTensorCsr,
&OrtApis::FillSparseTensorBlockSparse,
&OrtApis::CreateSparseTensorWithValuesAsOrtValue,
&OrtApis::UseCooIndices,
&OrtApis::UseCsrIndices,
&OrtApis::UseBlockSparseIndices,
&OrtApis::GetSparseTensorFormat,
&OrtApis::GetSparseTensorValuesTypeAndShape,
&OrtApis::GetSparseTensorValues,
&OrtApis::GetSparseTensorIndicesTypeShape,
&OrtApis::GetSparseTensorIndices,
// End of Version 9 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::HasValue,
&OrtApis::KernelContext_GetGPUComputeStream,
&OrtApis::GetTensorMemoryInfo,
&OrtApis::GetExecutionProviderApi,
&OrtApis::SessionOptionsSetCustomCreateThreadFn,
&OrtApis::SessionOptionsSetCustomThreadCreationOptions,
&OrtApis::SessionOptionsSetCustomJoinThreadFn,
&OrtApis::SetGlobalCustomCreateThreadFn,
&OrtApis::SetGlobalCustomThreadCreationOptions,
&OrtApis::SetGlobalCustomJoinThreadFn,
&OrtApis::SynchronizeBoundInputs,
&OrtApis::SynchronizeBoundOutputs,
// End of Version 10 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::SessionOptionsAppendExecutionProvider_CUDA_V2,
&OrtApis::CreateCUDAProviderOptions,
&OrtApis::UpdateCUDAProviderOptions,
&OrtApis::GetCUDAProviderOptionsAsString,
&OrtApis::ReleaseCUDAProviderOptions,
&OrtApis::SessionOptionsAppendExecutionProvider_MIGraphX,
// End of Version 11 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::AddExternalInitializers,
&OrtApis::CreateOpAttr,
&OrtApis::ReleaseOpAttr,
&OrtApis::CreateOp,
&OrtApis::InvokeOp,
&OrtApis::ReleaseOp,
&OrtApis::SessionOptionsAppendExecutionProvider,
&OrtApis::CopyKernelInfo,
&OrtApis::ReleaseKernelInfo,
// End of Version 12 - DO NOT MODIFY ABOVE (see above text for more information)
// Start of Version 13 (closed by the 'End of Version 13' marker below, so no longer in progress)
&OrtApis::GetTrainingApi,
&OrtApis::SessionOptionsAppendExecutionProvider_CANN,
&OrtApis::CreateCANNProviderOptions,
&OrtApis::UpdateCANNProviderOptions,
&OrtApis::GetCANNProviderOptionsAsString,
&OrtApis::ReleaseCANNProviderOptions,
// End of Version 13 - DO NOT MODIFY ABOVE (see above text for more information)
// Start of Version 14 API in progress, safe to modify/rename/rearrange until we ship
// NOTE(review): the ORT_VERSION static_assert after this table pins "1.14.1", which suggests
// version 14 has already shipped - confirm, then replace the line above with a plain
// 'Start of Version 14' marker and add an 'End of Version 14' marker after the last entry.
&OrtApis::MemoryInfoGetDeviceType,
&OrtApis::UpdateEnvWithCustomLogLevel,
&OrtApis::SetGlobalIntraOpThreadAffinity,
&OrtApis::RegisterCustomOpsLibrary_V2,
&OrtApis::RegisterCustomOpsUsingFunction,
&OrtApis::KernelInfo_GetInputCount,
&OrtApis::KernelInfo_GetOutputCount,
&OrtApis::KernelInfo_GetInputName,
&OrtApis::KernelInfo_GetOutputName,
&OrtApis::KernelInfo_GetInputTypeInfo,
&OrtApis::KernelInfo_GetOutputTypeInfo,
&OrtApis::KernelInfoGetAttribute_tensor,
&OrtApis::HasSessionConfigEntry,
&OrtApis::GetSessionConfigEntry,
};
// Asserts that perform some checks to ensure older versions of the OrtApi never change (they will detect an addition or a deletion, but not an addition and a deletion that cancel each other out)
// If any of these asserts hit, read the above 'Rules on how to add a new Ort API version'
// Each assertion pins the slot index (offset in pointer-sized units) of the LAST function of a
// shipped API version. Adding or removing an entry inside a shipped version shifts that index and
// breaks the build.
static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
static_assert(offsetof(OrtApi, ReleaseModelMetadata) / sizeof(void*) == 118, "Size of version 2 API cannot change");
static_assert(offsetof(OrtApi, AddFreeDimensionOverrideByName) / sizeof(void*) == 124,
              "Size of version 3 API cannot change");
static_assert(offsetof(OrtApi, ReleaseAvailableProviders) / sizeof(void*) == 126,
              "Size of version 4 API cannot change");
static_assert(offsetof(OrtApi, SetGlobalSpinControl) / sizeof(void*) == 149, "Size of version 5 API cannot change");
static_assert(offsetof(OrtApi, ReleaseArenaCfg) / sizeof(void*) == 157, "Size of version 6 API cannot change");
static_assert(offsetof(OrtApi, GetCurrentGpuDeviceId) / sizeof(void*) == 161, "Size of version 7 API cannot change");
static_assert(offsetof(OrtApi, CreateSessionFromArrayWithPrepackedWeightsContainer) / sizeof(void*) == 169, "Size of version 8 API cannot change");
static_assert(offsetof(OrtApi, GetSparseTensorIndices) / sizeof(void*) == 191, "Size of version 9 API cannot change");
static_assert(offsetof(OrtApi, SynchronizeBoundOutputs) / sizeof(void*) == 203, "Size of version 10 API cannot change");
static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_MIGraphX) / sizeof(void*) == 209, "Size of version 11 API cannot change");
static_assert(offsetof(OrtApi, ReleaseKernelInfo) / sizeof(void*) == 218, "Size of version 12 API cannot change");
static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 224, "Size of version 13 API cannot change");
// ORT_VERSION is pinned to 1.14.1 below, so the version 14 section is frozen as well: the 14
// entries after ReleaseCANNProviderOptions (slot 224) end at GetSessionConfigEntry (slot 238).
static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change");
// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.14.1",
              "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly");
// 1. Update the hardcoded version string in above static_assert to silence it
// 2. If there were any APIs added to ort_api_1_to_14 above:
// a. Add the 'End of version #' markers (pattern above should be obvious)
// b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version
// Returns the API table for the requested `version`.
// Every version from 1 through ORT_API_VERSION is served by the same table; any other version
// logs a message to stderr and yields nullptr.
ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
  const bool is_supported = version >= 1 && version <= ORT_API_VERSION;
  if (!is_supported) {
    fprintf(stderr, "The given version [%u] is not supported, only version 1 to %u is supported in this build.\n",
            version, ORT_API_VERSION);
    return nullptr;  // Unsupported version
  }
  return &ort_api_1_to_14;
}
// Returns the ORT_VERSION string this library was built with.
ORT_API(const char*, OrtApis::GetVersionString) {
return ORT_VERSION;
}
// Returns the address of the file-static OrtApiBase table (GetApi / GetVersionString).
// The returned pointer refers to static storage and must not be freed.
const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION {
return &ort_api_base;
}
// Releases an OrtEnv by delegating to OrtEnv::Release.
ORT_API(void, OrtApis::ReleaseEnv, OrtEnv* value) {
OrtEnv::Release(value);
}
// NOTE(review): DEFINE_RELEASE_ORT_OBJECT_FUNCTION is defined elsewhere; presumably each use
// generates OrtApis::Release<Name>, which deletes the handle as the given C++ type - confirm.
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata)