diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 674aa4f4c8..f343847f30 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1,6 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// Summary: The Ort C++ API is a header-only wrapper around the Ort C API. +// +// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors, +// and automatically releasing resources in the destructors. +// +// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. +// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). +// +// Only move assignment between objects is allowed; there are no copy constructors. Some objects have explicit 'Clone' +// methods for this purpose. 
+ #pragma once #include "onnxruntime_c_api.h" #include @@ -14,6 +25,7 @@ namespace Ort { using std::nullptr_t; +// Every C++ method that can fail reports the failure by throwing an exception of this type. struct Exception : std::exception { Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} @@ -25,14 +37,7 @@ struct Exception : std::exception { OrtErrorCode code_; }; -#define ORT_THROW_ON_ERROR(expr) \ - if (OrtStatus* onnx_status = (expr)) { \ - std::string ort_error_message = OrtGetErrorMessage(onnx_status); \ - OrtErrorCode ort_error_code = OrtGetErrorCode(onnx_status); \ - OrtReleaseStatus(onnx_status); \ - throw Ort::Exception(std::move(ort_error_message), ort_error_code); \ - } - +// This macro defines an OrtRelease(Ort<NAME>*) overload that forwards to the matching OrtRelease<NAME> C function, one per Ort* type #define ORT_DEFINE_RELEASE(NAME) \ inline void OrtRelease(Ort##NAME* ptr) { OrtRelease##NAME(ptr); } @@ -47,29 +52,6 @@ ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); ORT_DEFINE_RELEASE(TypeInfo); ORT_DEFINE_RELEASE(Value); -template -struct TypeToTensorType; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = 
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; }; -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; }; - template struct Base { Base() = default; @@ -261,415 +243,39 @@ struct AllocatorInfo : Base { explicit AllocatorInfo(OrtAllocatorInfo* p) : Base{p} {} }; -} // namespace Ort +// +// Custom OPs (only needed to implement custom OPs) +// -namespace Ort { - -inline Allocator Allocator::CreateDefault() { - OrtAllocator* p; - ORT_THROW_ON_ERROR(OrtCreateDefaultAllocator(&p)); - return Allocator(p); -} - -inline void* Allocator::Alloc(size_t size) { - return OrtAllocatorAlloc(p_, size); -} - -inline void Allocator::Free(void* p) { - OrtAllocatorFree(p_, p); -} - -inline const OrtAllocatorInfo* Allocator::GetInfo() const { - return OrtAllocatorGetInfo(p_); -} - -inline AllocatorInfo AllocatorInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { - OrtAllocatorInfo* p; - ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(type, mem_type, &p)); - return AllocatorInfo(p); -} - -inline AllocatorInfo::AllocatorInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo(name, type, id, mem_type, &p_)); -} - -inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) { - ORT_THROW_ON_ERROR(OrtCreateEnv(default_warning_level, logid, &p_)); -} - -inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { - ORT_THROW_ON_ERROR(OrtCreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_)); -} - -inline CustomOpDomain::CustomOpDomain(const char* domain) - : 
Base{OrtCreateCustomOpDomain(domain)} { -} - -inline void CustomOpDomain::Add(OrtCustomOp* op) { - ORT_THROW_ON_ERROR(OrtCustomOpDomain_Add(p_, op)); -} - -inline RunOptions::RunOptions() : Base{OrtCreateRunOptions()} {} - -inline RunOptions& RunOptions::SetRunLogVerbosityLevel(unsigned int level) { - ORT_THROW_ON_ERROR(OrtRunOptionsSetRunLogVerbosityLevel(p_, level)); - return *this; -} - -inline unsigned int RunOptions::GetRunLogVerbosityLevel() const { - return OrtRunOptionsGetRunLogVerbosityLevel(p_); -} - -inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { - ORT_THROW_ON_ERROR(OrtRunOptionsSetRunTag(p_, run_tag)); - return *this; -} - -inline const char* RunOptions::GetRunTag() const { - return OrtRunOptionsGetRunTag(p_); -} - -inline RunOptions& RunOptions::SetTerminate(bool flag) { - OrtRunOptionsSetTerminate(p_, flag ? 1 : 0); - return *this; -} - -inline SessionOptions::SessionOptions() : Base{OrtCreateSessionOptions()} { -} - -inline SessionOptions SessionOptions::Clone() const { - return SessionOptions{OrtCloneSessionOptions(p_)}; -} - -inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool_size) { - if (OrtSetSessionThreadPoolSize(p_, session_thread_pool_size) == -1) - throw Exception("Error calling SessionOptions::SetThreadPoolSize", ORT_FAIL); - return *this; -} - -inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(uint32_t graph_optimization_level) { - if (OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level) == -1) - throw Exception("Error calling SessionOptions::SetGraphOptimizationLevel", ORT_FAIL); - return *this; -} - -inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { - OrtEnableProfiling(p_, profile_file_prefix); - return *this; -} - -inline SessionOptions& SessionOptions::DisableProfiling() { - OrtDisableProfiling(p_); - return *this; -} - -inline SessionOptions& SessionOptions::EnableMemPattern() { - OrtEnableMemPattern(p_); - return 
*this; -} - -inline SessionOptions& SessionOptions::DisableMemPattern() { - OrtDisableMemPattern(p_); - return *this; -} - -inline SessionOptions& SessionOptions::EnableCpuMemArena() { - OrtEnableCpuMemArena(p_); - return *this; -} - -inline SessionOptions& SessionOptions::DisableCpuMemArena() { - OrtDisableCpuMemArena(p_); - return *this; -} - -inline SessionOptions& SessionOptions::EnableSequentialExecution() { - OrtEnableSequentialExecution(p_); - return *this; -} - -inline SessionOptions& SessionOptions::DisableSequentialExecution() { - OrtDisableSequentialExecution(p_); - return *this; -} - -inline SessionOptions& SessionOptions::SetLogId(const char* logid) { - OrtSetSessionLogId(p_, logid); - return *this; -} -inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) { - ORT_THROW_ON_ERROR(OrtAddCustomOpDomain(p_, custom_op_domain)); - return *this; -} - -inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { - ORT_THROW_ON_ERROR(OrtCreateSession(env, model_path, options, &p_)); -} - -inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { - ORT_THROW_ON_ERROR(OrtCreateSessionFromArray(env, model_data, model_data_length, options, &p_)); -} - -inline std::vector Session::Run(const RunOptions& run_options, const char* const* input_names, Value* input_values, size_t input_count, - const char* const* output_names, size_t output_names_count) { - std::vector output_values; - for (size_t i = 0; i < output_names_count; i++) - output_values.emplace_back(nullptr); - Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_names_count); - return output_values; -} - -inline void Session::Run(const RunOptions& run_options, const char* const* input_names, Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count) { - static_assert(sizeof(Value) == 
sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto ort_input_values = reinterpret_cast(input_values); - auto ort_output_values = reinterpret_cast(output_values); - ORT_THROW_ON_ERROR(OrtRun(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); -} - -inline size_t Session::GetInputCount() const { - size_t out; - ORT_THROW_ON_ERROR(OrtSessionGetInputCount(p_, &out)); - return out; -} - -inline size_t Session::GetOutputCount() const { - size_t out; - ORT_THROW_ON_ERROR(OrtSessionGetOutputCount(p_, &out)); - return out; -} - -inline char* Session::GetInputName(size_t index, OrtAllocator* allocator) const { - char* out; - ORT_THROW_ON_ERROR(OrtSessionGetInputName(p_, index, allocator, &out)); - return out; -} - -inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const { - char* out; - ORT_THROW_ON_ERROR(OrtSessionGetOutputName(p_, index, allocator, &out)); - return out; -} - -inline TypeInfo Session::GetInputTypeInfo(size_t index) const { - OrtTypeInfo* out; - ORT_THROW_ON_ERROR(OrtSessionGetInputTypeInfo(p_, index, &out)); - return TypeInfo{out}; -} - -inline TypeInfo Session::GetOutputTypeInfo(size_t index) const { - OrtTypeInfo* out; - ORT_THROW_ON_ERROR(OrtSessionGetOutputTypeInfo(p_, index, &out)); - return TypeInfo{out}; -} - -inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const { - return OrtGetTensorElementType(p_); -} - -inline size_t TensorTypeAndShapeInfo::GetElementCount() const { - return static_cast(OrtGetTensorShapeElementCount(p_)); -} - -inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const { - return OrtGetDimensionsCount(p_); -} - -inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const { - OrtGetDimensions(p_, values, values_count); -} - -inline std::vector TensorTypeAndShapeInfo::GetShape() const { - std::vector 
out(GetDimensionsCount(), 0); - GetDimensions(out.data(), out.size()); - return out; -} - -inline Unowned TypeInfo::GetTensorTypeAndShapeInfo() const { - return Unowned{const_cast(OrtCastTypeInfoToTensorInfo(p_))}; -} - -inline ONNXType TypeInfo::GetONNXType() const { - return OrtOnnxTypeFromTypeInfo(p_); -} - -template -inline Value Value::CreateTensor(const OrtAllocatorInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { - return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); -} - -inline Value Value::CreateTensor(const OrtAllocatorInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type) { - OrtValue* out; - ORT_THROW_ON_ERROR(OrtCreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); - return Value{out}; -} - -template -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { - return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); -} - -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { - OrtValue* out; - ORT_THROW_ON_ERROR(OrtCreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); - return Value{out}; -} - -ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, - _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, - _Out_ OrtValue** out); - -inline Value Value::CreateMap(Value& keys, Value& values) { - OrtValue* out; - OrtValue* inputs[2] = {keys, values}; - ORT_THROW_ON_ERROR(OrtCreateValue(inputs, 2, ONNX_TYPE_MAP, &out)); - return Value{out}; -} - -inline Value Value::CreateSequence(std::vector& values) { - OrtValue* out; - std::vector values_ort{values.data(), values.data() + values.size()}; - ORT_THROW_ON_ERROR(OrtCreateValue(values_ort.data(), 
values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); - return Value{out}; -} - -inline bool Value::IsTensor() const { - return OrtIsTensor(p_) != 0; -} - -inline size_t Value::GetCount() const { - size_t out; - ORT_THROW_ON_ERROR(OrtGetValueCount(p_, &out)); - return out; -} - -inline Value Value::GetValue(int index, OrtAllocator* allocator) const { - OrtValue* out; - ORT_THROW_ON_ERROR(OrtGetValue(p_, index, allocator, &out)); - return Value{out}; -} - -inline size_t Value::GetStringTensorDataLength() const { - size_t out; - ORT_THROW_ON_ERROR(OrtGetStringTensorDataLength(p_, &out)); - return out; -} - -inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { - ORT_THROW_ON_ERROR(OrtGetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count)); -} - -template -T* Value::GetTensorMutableData() { - T* out; - ORT_THROW_ON_ERROR(OrtGetTensorMutableData(p_, (void**)&out)); - return out; -} - -inline TypeInfo Value::GetTypeInfo() const { - OrtTypeInfo* output; - ORT_THROW_ON_ERROR(OrtGetTypeInfo(p_, &output)); - return TypeInfo{output}; -} - -inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const { - OrtTensorTypeAndShapeInfo* output; - ORT_THROW_ON_ERROR(OrtGetTensorTypeAndShape(p_, &output)); - return TensorTypeAndShapeInfo{output}; -} - -} // namespace Ort - -namespace Ort { struct CustomOpApi { CustomOpApi(const OrtCustomOpApi& api) : api_(api) {} - template + template // T is only implemented for float and int64_t T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); - OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) { - OrtTensorTypeAndShapeInfo* out; - ORT_THROW_ON_ERROR(api_.GetTensorTypeAndShape(value, &out)); - return out; - } - - int64_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) { - return api_.GetTensorShapeElementCount(info); - } - - ONNXTensorElementDataType GetTensorElementType(const 
OrtTensorTypeAndShapeInfo* info) { - return api_.GetTensorElementType(info); - } - - size_t GetDimensionCount(_In_ const OrtTensorTypeAndShapeInfo* info) { - return api_.GetDimensionCount(info); - } - - void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { - api_.GetDimensions(info, dim_values, dim_values_length); - } - - void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { - api_.SetDimensions(info, dim_values, dim_count); - } + OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); + int64_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); + ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); + size_t GetDimensionCount(_In_ const OrtTensorTypeAndShapeInfo* info); + void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); + void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); template - T* GetTensorMutableData(_Inout_ OrtValue* value) { - T* data; - ORT_THROW_ON_ERROR(api_.GetTensorMutableData(value, reinterpret_cast(&data))); - return data; - } - + T* GetTensorMutableData(_Inout_ OrtValue* value); template - const T* GetTensorData(_Inout_ const OrtValue* value) { - return GetTensorMutableData(const_cast(value)); - } + const T* GetTensorData(_Inout_ const OrtValue* value); - std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info) { - std::vector output(GetDimensionCount(info)); - GetDimensions(info, output.data(), output.size()); - return output; - } - - void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) { - api_.ReleaseTensorTypeAndShapeInfo(input); - } - - size_t KernelContext_GetInputCount(const OrtKernelContext* context) { - return api_.KernelContext_GetInputCount(context); - } - - const OrtValue* KernelContext_GetInput(const 
OrtKernelContext* context, _In_ size_t index) { - return api_.KernelContext_GetInput(context, index); - } - - size_t KernelContext_GetOutputCount(const OrtKernelContext* context) { - return api_.KernelContext_GetOutputCount(context); - } - - OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) { - return api_.KernelContext_GetOutput(context, index, dim_values, dim_count); - } + std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); + void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); + size_t KernelContext_GetInputCount(const OrtKernelContext* context); + const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); + size_t KernelContext_GetOutputCount(const OrtKernelContext* context); + OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); private: const OrtCustomOpApi& api_; }; -template <> -inline float CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { - float out; - ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_float(info, name, &out)); - return out; -} - -template <> -inline int64_t CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { - int64_t out; - ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_int64(info, name, &out)); - return out; -} - template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -690,4 +296,4 @@ struct CustomOpBase : OrtCustomOp { } // namespace Ort -#undef ORT_REDIRECT_SIMPLE_FUNCTION_CALL +#include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h new file mode 100644 index 0000000000..6a6dc758c1 --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -0,0 +1,437 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved.
+// Licensed under the MIT License.
+
+// These are the inline implementations of the C++ header APIs. They live in this separate file so as not to clutter
+// the main C++ header with implementation details.
+
+// Evaluates 'expr', which must yield an OrtStatus*. A non-null result is treated as a failure: its message and error
+// code are copied out, the status is released, and an Ort::Exception carrying both is thrown.
+// NOTE: this expands to a bare 'if' statement, so do not use it as the unbraced body of an if/else.
+#define ORT_THROW_ON_ERROR(expr)                                        \
+  if (OrtStatus* onnx_status = (expr)) {                                \
+    std::string ort_error_message = OrtGetErrorMessage(onnx_status);    \
+    OrtErrorCode ort_error_code = OrtGetErrorCode(onnx_status);         \
+    OrtReleaseStatus(onnx_status);                                      \
+    throw Ort::Exception(std::move(ort_error_message), ort_error_code); \
+  }
+
+namespace Ort {
+
+// This template converts a C++ type into its ONNXTensorElementDataType, exposed as the static member 'type'.
+// The primary template is declared but never defined, so using ::type with a T that has no specialization
+// below fails to compile.
+template <typename T>
+struct TypeToTensorType;
+template <>
+struct TypeToTensorType<float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+template <>
+struct TypeToTensorType<double> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; };
+template <>
+struct TypeToTensorType<int8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
+template <>
+struct TypeToTensorType<int16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; };
+template <>
+struct TypeToTensorType<int32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
+template <>
+struct TypeToTensorType<int64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; };
+template <>
+struct TypeToTensorType<uint8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; };
+template <>
+struct TypeToTensorType<uint16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; };
+template <>
+struct TypeToTensorType<uint32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; };
+template <>
+struct TypeToTensorType<uint64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; };
+
+inline 
Allocator Allocator::CreateDefault() { + OrtAllocator* p; + ORT_THROW_ON_ERROR(OrtCreateDefaultAllocator(&p)); + return Allocator(p); +} + +inline void* Allocator::Alloc(size_t size) { + return OrtAllocatorAlloc(p_, size); +} + +inline void Allocator::Free(void* p) { + OrtAllocatorFree(p_, p); +} + +inline const OrtAllocatorInfo* Allocator::GetInfo() const { + return OrtAllocatorGetInfo(p_); +} + +inline AllocatorInfo AllocatorInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { + OrtAllocatorInfo* p; + ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(type, mem_type, &p)); + return AllocatorInfo(p); +} + +inline AllocatorInfo::AllocatorInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo(name, type, id, mem_type, &p_)); +} + +inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) { + ORT_THROW_ON_ERROR(OrtCreateEnv(default_warning_level, logid, &p_)); +} + +inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { + ORT_THROW_ON_ERROR(OrtCreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_)); +} + +inline CustomOpDomain::CustomOpDomain(const char* domain) + : Base{OrtCreateCustomOpDomain(domain)} { +} + +inline void CustomOpDomain::Add(OrtCustomOp* op) { + ORT_THROW_ON_ERROR(OrtCustomOpDomain_Add(p_, op)); +} + +inline RunOptions::RunOptions() : Base{OrtCreateRunOptions()} {} + +inline RunOptions& RunOptions::SetRunLogVerbosityLevel(unsigned int level) { + ORT_THROW_ON_ERROR(OrtRunOptionsSetRunLogVerbosityLevel(p_, level)); + return *this; +} + +inline unsigned int RunOptions::GetRunLogVerbosityLevel() const { + return OrtRunOptionsGetRunLogVerbosityLevel(p_); +} + +inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { + ORT_THROW_ON_ERROR(OrtRunOptionsSetRunTag(p_, run_tag)); + return *this; +} + +inline const char* RunOptions::GetRunTag() const { 
+ return OrtRunOptionsGetRunTag(p_); +} + +inline RunOptions& RunOptions::SetTerminate(bool flag) { + OrtRunOptionsSetTerminate(p_, flag ? 1 : 0); + return *this; +} + +inline SessionOptions::SessionOptions() : Base{OrtCreateSessionOptions()} { +} + +inline SessionOptions SessionOptions::Clone() const { + return SessionOptions{OrtCloneSessionOptions(p_)}; +} + +inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool_size) { + if (OrtSetSessionThreadPoolSize(p_, session_thread_pool_size) == -1) + throw Exception("Error calling SessionOptions::SetThreadPoolSize", ORT_FAIL); + return *this; +} + +inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(uint32_t graph_optimization_level) { + if (OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level) == -1) + throw Exception("Error calling SessionOptions::SetGraphOptimizationLevel", ORT_FAIL); + return *this; +} + +inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { + OrtEnableProfiling(p_, profile_file_prefix); + return *this; +} + +inline SessionOptions& SessionOptions::DisableProfiling() { + OrtDisableProfiling(p_); + return *this; +} + +inline SessionOptions& SessionOptions::EnableMemPattern() { + OrtEnableMemPattern(p_); + return *this; +} + +inline SessionOptions& SessionOptions::DisableMemPattern() { + OrtDisableMemPattern(p_); + return *this; +} + +inline SessionOptions& SessionOptions::EnableCpuMemArena() { + OrtEnableCpuMemArena(p_); + return *this; +} + +inline SessionOptions& SessionOptions::DisableCpuMemArena() { + OrtDisableCpuMemArena(p_); + return *this; +} + +inline SessionOptions& SessionOptions::EnableSequentialExecution() { + OrtEnableSequentialExecution(p_); + return *this; +} + +inline SessionOptions& SessionOptions::DisableSequentialExecution() { + OrtDisableSequentialExecution(p_); + return *this; +} + +inline SessionOptions& SessionOptions::SetLogId(const char* logid) { + OrtSetSessionLogId(p_, logid); + 
return *this;
+}
+inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) {
+  ORT_THROW_ON_ERROR(OrtAddCustomOpDomain(p_, custom_op_domain));
+  return *this;
+}
+
+// Creates the session via OrtCreateSession from 'model_path'; throws Ort::Exception on failure.
+inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
+  ORT_THROW_ON_ERROR(OrtCreateSession(env, model_path, options, &p_));
+}
+
+// Creates the session via OrtCreateSessionFromArray from an in-memory buffer of 'model_data_length' bytes;
+// throws Ort::Exception on failure.
+inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
+  ORT_THROW_ON_ERROR(OrtCreateSessionFromArray(env, model_data, model_data_length, options, &p_));
+}
+
+// Convenience overload: builds 'output_names_count' empty (nullptr) Values, forwards to the overload below to
+// have them filled in, and returns them by value.
+inline std::vector<Value> Session::Run(const RunOptions& run_options, const char* const* input_names, Value* input_values, size_t input_count,
+                                       const char* const* output_names, size_t output_names_count) {
+  std::vector<Value> output_values;
+  output_values.reserve(output_names_count);  // size is known up front: allocate once instead of growing
+  for (size_t i = 0; i < output_names_count; i++)
+    output_values.emplace_back(nullptr);
+  Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_names_count);
+  return output_values;
+}
+
+// Runs the model through OrtRun, writing results into the caller-provided 'output_values' array.
+// Throws Ort::Exception if OrtRun reports an error.
+inline void Session::Run(const RunOptions& run_options, const char* const* input_names, Value* input_values, size_t input_count,
+                         const char* const* output_names, Value* output_values, size_t output_count) {
+  static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
+  // NOTE(review): the cast target types were missing from the reviewed text; confirm them against OrtRun's
+  // input/output parameter types.
+  auto ort_input_values = reinterpret_cast<const OrtValue**>(input_values);
+  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
+  ORT_THROW_ON_ERROR(OrtRun(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
+}
+
+// Number of model inputs as reported by OrtSessionGetInputCount.
+inline size_t Session::GetInputCount() const {
+  size_t out;
+  ORT_THROW_ON_ERROR(OrtSessionGetInputCount(p_, &out));
+  return out;
+}
+
+// Number of model outputs as reported by OrtSessionGetOutputCount.
+inline size_t Session::GetOutputCount() const {
+  size_t out;
+  ORT_THROW_ON_ERROR(OrtSessionGetOutputCount(p_, &out));
+  return out;
+}
+
+inline char* Session::GetInputName(size_t index, 
OrtAllocator* allocator) const {
+  char* out;
+  ORT_THROW_ON_ERROR(OrtSessionGetInputName(p_, index, allocator, &out));
+  return out;
+}
+
+// Returns the name of output 'index' as a raw C string.
+// NOTE(review): the string is presumably allocated through 'allocator' and must be freed by the caller - confirm.
+inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const {
+  char* out;
+  ORT_THROW_ON_ERROR(OrtSessionGetOutputName(p_, index, allocator, &out));
+  return out;
+}
+
+inline TypeInfo Session::GetInputTypeInfo(size_t index) const {
+  OrtTypeInfo* out;
+  ORT_THROW_ON_ERROR(OrtSessionGetInputTypeInfo(p_, index, &out));
+  return TypeInfo{out};
+}
+
+inline TypeInfo Session::GetOutputTypeInfo(size_t index) const {
+  OrtTypeInfo* out;
+  ORT_THROW_ON_ERROR(OrtSessionGetOutputTypeInfo(p_, index, &out));
+  return TypeInfo{out};
+}
+
+inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const {
+  return OrtGetTensorElementType(p_);
+}
+
+// Converts the C API's element count to size_t without any range check.
+inline size_t TensorTypeAndShapeInfo::GetElementCount() const {
+  return static_cast<size_t>(OrtGetTensorShapeElementCount(p_));
+}
+
+inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const {
+  return OrtGetDimensionsCount(p_);
+}
+
+// Copies up to 'values_count' dimensions into the caller-provided 'values' buffer.
+inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const {
+  OrtGetDimensions(p_, values, values_count);
+}
+
+// Returns every dimension as a vector: sized from GetDimensionsCount(), then filled by GetDimensions().
+inline std::vector<int64_t> TensorTypeAndShapeInfo::GetShape() const {
+  std::vector<int64_t> out(GetDimensionsCount(), 0);
+  GetDimensions(out.data(), out.size());
+  return out;
+}
+
+inline Unowned<TensorTypeAndShapeInfo> TypeInfo::GetTensorTypeAndShapeInfo() const {
+  return Unowned<TensorTypeAndShapeInfo>{const_cast<OrtTensorTypeAndShapeInfo*>(OrtCastTypeInfoToTensorInfo(p_))};
+}
+
+inline ONNXType TypeInfo::GetONNXType() const {
+  return OrtOnnxTypeFromTypeInfo(p_);
+}
+
+// Typed overload: 'p_data_element_count' is a count of T elements; it is converted to a byte count and the ONNX
+// element type is derived from T before forwarding to the untyped overload.
+template <typename T>
+inline Value Value::CreateTensor(const OrtAllocatorInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
+  return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
+}
+
+inline Value Value::CreateTensor(const OrtAllocatorInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t 
shape_len,
+                                 ONNXTensorElementDataType type) {
+  OrtValue* out;
+  ORT_THROW_ON_ERROR(OrtCreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
+  return Value{out};
+}
+
+// Typed overload: derives the ONNX element type from T and forwards to the untyped overload below.
+template <typename T>
+inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
+  return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
+}
+
+// Creates a tensor of the given shape and element type through OrtCreateTensorAsOrtValue.
+// (The C function is declared by onnxruntime_c_api.h; it is deliberately not redeclared inside namespace Ort.)
+inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
+  OrtValue* out;
+  ORT_THROW_ON_ERROR(OrtCreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
+  return Value{out};
+}
+
+// Builds an ONNX_TYPE_MAP value from a keys Value and a values Value.
+inline Value Value::CreateMap(Value& keys, Value& values) {
+  OrtValue* out;
+  OrtValue* inputs[2] = {keys, values};
+  ORT_THROW_ON_ERROR(OrtCreateValue(inputs, 2, ONNX_TYPE_SEQUENCE == ONNX_TYPE_MAP ? ONNX_TYPE_MAP : ONNX_TYPE_MAP, &out));
+  return Value{out};
+}
+
+// Builds an ONNX_TYPE_SEQUENCE value from 'values'. The underlying OrtValue* of each element is first copied
+// into a temporary array, which is what OrtCreateValue receives.
+inline Value Value::CreateSequence(std::vector<Value>& values) {
+  OrtValue* out;
+  std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
+  ORT_THROW_ON_ERROR(OrtCreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
+  return Value{out};
+}
+
+inline bool Value::IsTensor() const {
+  return OrtIsTensor(p_) != 0;
+}
+
+inline size_t Value::GetCount() const {
+  size_t out;
+  ORT_THROW_ON_ERROR(OrtGetValueCount(p_, &out));
+  return out;
+}
+
+inline Value Value::GetValue(int index, OrtAllocator* allocator) const {
+  OrtValue* out;
+  ORT_THROW_ON_ERROR(OrtGetValue(p_, index, allocator, &out));
+  return Value{out};
+}
+
+inline size_t Value::GetStringTensorDataLength() const {
+  size_t out;
+  ORT_THROW_ON_ERROR(OrtGetStringTensorDataLength(p_, &out));
+  return out;
+}
+
+inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
+  
ORT_THROW_ON_ERROR(OrtGetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count));
+}
+
+// Returns the tensor's data buffer viewed as T*. No check is made here that T matches the tensor's element type.
+template <typename T>
+inline T* Value::GetTensorMutableData() {
+  T* out;
+  ORT_THROW_ON_ERROR(OrtGetTensorMutableData(p_, reinterpret_cast<void**>(&out)));
+  return out;
+}
+
+inline TypeInfo Value::GetTypeInfo() const {
+  OrtTypeInfo* output;
+  ORT_THROW_ON_ERROR(OrtGetTypeInfo(p_, &output));
+  return TypeInfo{output};
+}
+
+inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const {
+  OrtTensorTypeAndShapeInfo* output;
+  ORT_THROW_ON_ERROR(OrtGetTensorTypeAndShape(p_, &output));
+  return TensorTypeAndShapeInfo{output};
+}
+
+//
+// Custom OP API Inlines
+//
+
+// Reads the float attribute 'name' from 'info'; throws Ort::Exception if the lookup fails.
+template <>
+inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
+  float out;
+  ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_float(info, name, &out));
+  return out;
+}
+
+// Reads the int64_t attribute 'name' from 'info'; throws Ort::Exception if the lookup fails.
+template <>
+inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
+  int64_t out;
+  ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_int64(info, name, &out));
+  return out;
+}
+
+// Returns a raw OrtTensorTypeAndShapeInfo*; this class offers ReleaseTensorTypeAndShapeInfo() for releasing it.
+inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
+  OrtTensorTypeAndShapeInfo* out;
+  ORT_THROW_ON_ERROR(api_.GetTensorTypeAndShape(value, &out));
+  return out;
+}
+
+inline int64_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
+  return api_.GetTensorShapeElementCount(info);
+}
+
+inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
+  return api_.GetTensorElementType(info);
+}
+
+inline size_t CustomOpApi::GetDimensionCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
+  return api_.GetDimensionCount(info);
+}
+
+inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
+  api_.GetDimensions(info, dim_values, dim_values_length);
+}
+
+inline void 
CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
+  api_.SetDimensions(info, dim_values, dim_count);
+}
+
+// Returns the tensor's data buffer viewed as T*. No check is made here that T matches the tensor's element type.
+template <typename T>
+inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
+  T* data;
+  ORT_THROW_ON_ERROR(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
+  return data;
+}
+
+// Read-only counterpart of GetTensorMutableData. The const_cast exists only to reuse that function's signature;
+// the result is handed back as const T*.
+template <typename T>
+inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
+  return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
+}
+
+// Returns every dimension as a vector: sized from GetDimensionCount(), then filled by GetDimensions().
+inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
+  std::vector<int64_t> output(GetDimensionCount(info));
+  GetDimensions(info, output.data(), output.size());
+  return output;
+}
+
+inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
+  api_.ReleaseTensorTypeAndShapeInfo(input);
+}
+
+inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
+  return api_.KernelContext_GetInputCount(context);
+}
+
+inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
+  return api_.KernelContext_GetInput(context, index);
+}
+
+inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
+  return api_.KernelContext_GetOutputCount(context);
+}
+
+// Fetches output 'index' from the kernel context, passing the requested shape ('dim_values', 'dim_count') through.
+inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
+  return api_.KernelContext_GetOutput(context, index, dim_values, dim_count);
+}
+
+}  // namespace Ort