C API - Remove reference counting (#344)

This commit is contained in:
Ryan Hill 2019-01-25 19:41:10 -08:00 committed by GitHub
parent 6349114583
commit d875ab2acd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 263 additions and 661 deletions

View file

@ -1,106 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Handle to a native CPU execution provider factory obtained from
/// NativeMethods.OrtCreateCpuExecutionProviderFactory.
/// </summary>
internal class CpuExecutionProviderFactory: NativeOnnxObjectHandle
{
    // Backing store for Default; the factory is only constructed on first access.
    protected static readonly Lazy<CpuExecutionProviderFactory> _default = new Lazy<CpuExecutionProviderFactory>(() => new CpuExecutionProviderFactory());

    /// <summary>
    /// Creates the native factory and stores its pointer in this handle.
    /// </summary>
    /// <param name="useArena">Passed to the native call as 1 when true, 0 when false.</param>
    /// <exception cref="OnnxRuntimeException">The native call reported a failure status.</exception>
    public CpuExecutionProviderFactory(bool useArena=true)
        :base(IntPtr.Zero)
    {
        int useArenaInt = useArena ? 1 : 0;
        try
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuExecutionProviderFactory(useArenaInt, out handle));
        }
        catch (OnnxRuntimeException)
        {
            if (IsInvalid)
            {
                ReleaseHandle();
                handle = IntPtr.Zero;
            }

            // Bare rethrow keeps the original stack trace; "throw e;" would reset it to this frame.
            throw;
        }
    }

    /// <summary>
    /// Shared instance created with the default constructor arguments.
    /// </summary>
    public static CpuExecutionProviderFactory Default
    {
        get
        {
            return _default.Value;
        }
    }
}
/// <summary>
/// Handle to a native MKL-DNN execution provider factory obtained from
/// NativeMethods.OrtCreateMkldnnExecutionProviderFactory.
/// </summary>
internal class MklDnnExecutionProviderFactory : NativeOnnxObjectHandle
{
    // Backing store for Default; the factory is only constructed on first access.
    protected static readonly Lazy<MklDnnExecutionProviderFactory> _default = new Lazy<MklDnnExecutionProviderFactory>(() => new MklDnnExecutionProviderFactory());

    /// <summary>
    /// Creates the native factory and stores its pointer in this handle.
    /// </summary>
    /// <param name="useArena">Passed to the native call as 1 when true, 0 when false.</param>
    /// <exception cref="OnnxRuntimeException">The native call reported a failure status.</exception>
    public MklDnnExecutionProviderFactory(bool useArena = true)
        :base(IntPtr.Zero)
    {
        int useArenaInt = useArena ? 1 : 0;
        try
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMkldnnExecutionProviderFactory(useArenaInt, out handle));
        }
        catch (OnnxRuntimeException)
        {
            if (IsInvalid)
            {
                ReleaseHandle();
                handle = IntPtr.Zero;
            }

            // Bare rethrow keeps the original stack trace; "throw e;" would reset it to this frame.
            throw;
        }
    }

    /// <summary>
    /// Shared instance created with the default constructor arguments.
    /// </summary>
    public static MklDnnExecutionProviderFactory Default
    {
        get
        {
            return _default.Value;
        }
    }
}
/// <summary>
/// Handle to a native CUDA execution provider factory obtained from
/// NativeMethods.OrtCreateCUDAExecutionProviderFactory.
/// </summary>
internal class CudaExecutionProviderFactory : NativeOnnxObjectHandle
{
    // Backing store for Default; the factory is only constructed on first access.
    protected static readonly Lazy<CudaExecutionProviderFactory> _default = new Lazy<CudaExecutionProviderFactory>(() => new CudaExecutionProviderFactory());

    /// <summary>
    /// Creates the native factory and stores its pointer in this handle.
    /// </summary>
    /// <param name="deviceId">Device id forwarded unchanged to the native call.</param>
    /// <exception cref="OnnxRuntimeException">The native call reported a failure status.</exception>
    public CudaExecutionProviderFactory(int deviceId = 0)
        : base(IntPtr.Zero)
    {
        try
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCUDAExecutionProviderFactory(deviceId, out handle));
        }
        catch (OnnxRuntimeException)
        {
            if (IsInvalid)
            {
                ReleaseHandle();
                handle = IntPtr.Zero;
            }

            // Bare rethrow keeps the original stack trace; "throw e;" would reset it to this frame.
            throw;
        }
    }

    /// <summary>
    /// Shared instance created with the default constructor arguments (device id 0).
    /// </summary>
    public static CudaExecutionProviderFactory Default
    {
        get
        {
            return _default.Value;
        }
    }
}
}

View file

@ -45,9 +45,9 @@ namespace Microsoft.ML.OnnxRuntime
try
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.NativeHandle, out _nativeHandle));
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options._nativePtr, out _nativeHandle));
else
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.NativeHandle, out _nativeHandle));
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options._nativePtr, out _nativeHandle));
// Initialize input/output metadata
_inputMetadata = new Dictionary<string, NodeMetadata>();
@ -275,7 +275,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (typeInfo != IntPtr.Zero)
{
NativeMethods.OrtReleaseObject(typeInfo);
NativeMethods.OrtReleaseTypeInfo(typeInfo);
}
}
}
@ -292,7 +292,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (typeInfo != IntPtr.Zero)
{
NativeMethods.OrtReleaseObject(typeInfo);
NativeMethods.OrtReleaseTypeInfo(typeInfo);
}
}
}

View file

@ -198,7 +198,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (typeAndShape != IntPtr.Zero)
{
NativeMethods.OrtReleaseObject(typeAndShape);
NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
}
}

View file

@ -141,7 +141,7 @@ namespace Microsoft.ML.OnnxRuntime
protected static void Delete(IntPtr allocator)
{
NativeMethods.OrtReleaseObject(allocator);
NativeMethods.OrtReleaseAllocator(allocator);
}
protected override bool ReleaseHandle()

View file

@ -90,20 +90,22 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/name);
// release the typeinfo using OrtReleaseObject
// release the typeinfo using OrtReleaseTypeInfo
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/OrtSessionGetInputTypeInfo(
IntPtr /*(const OrtSession*)*/ session,
ulong index, //TODO: port for size_t
out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo);
// release the typeinfo using OrtReleaseObject
// release the typeinfo using OrtReleaseTypeInfo
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/OrtSessionGetOutputTypeInfo(
IntPtr /*(const OrtSession*)*/ session,
ulong index, //TODO: port for size_t
out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo);
[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session);
[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtReleaseSession(IntPtr /*(OrtSession*)*/session);
@ -112,11 +114,12 @@ namespace Microsoft.ML.OnnxRuntime
#region SessionOptions API
//Release using OrtReleaseSessionOptions
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*OrtSessionOptions* */ OrtCreateSessionOptions();
[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtSessionOptions*)*/OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions);
@ -153,22 +156,20 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern int OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
///**
// * The order of invocation indicates the preference order as well. In other words call this method
// * on your most preferred execution provider first followed by the less preferred ones.
// * Calling this API is optional in which case onnxruntime will use its internal CPU execution provider.
// */
[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtSessionOptionsAppendExecutionProvider(IntPtr /*(OrtSessionOptions*)*/ options, IntPtr /* (OrtProviderFactoryPtr*)*/ factory);
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CPU(IntPtr /*(OrtSessionOptions*) */ options, int use_arena);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCpuExecutionProviderFactory(int use_arena, out IntPtr /*(OrtProviderFactoryPtr*)*/ factory);
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Mkldnn(IntPtr /*(OrtSessionOptions*) */ options, int use_arena);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateMkldnnExecutionProviderFactory(int use_arena, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCUDAExecutionProviderFactory(int device_id, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern IntPtr /*(OrtStatus*)*/ OrtCreateNupharExecutionProviderFactory(int device_id, string target_str, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
@ -220,13 +221,8 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/OrtCreateDefaultAllocator(out IntPtr /*(OrtAllocator**)*/ allocator);
/// <summary>
/// Releases/Unrefs any object, including the Allocator
/// </summary>
/// <param name="ptr"></param>
/// <returns>remaining ref count</returns>
[DllImport(nativeLib, CharSet = charSet)]
public static extern uint /*remaining ref count*/ OrtReleaseObject(IntPtr /*(void*)*/ ptr);
public static extern void OrtReleaseAllocator(IntPtr /*(OrtAllocator*)*/ allocator);
/// <summary>
/// Release any object allocated by an allocator
@ -265,6 +261,10 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorShapeAndType(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtReleaseTensorTypeAndShapeInfo(IntPtr /*(OrtTensorTypeAndShapeInfo*)*/ value);
[DllImport(nativeLib, CharSet = charSet)]
public static extern TensorElementType OrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);

View file

@ -1,28 +0,0 @@
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// SafeHandle over a native object pointer. A zero pointer is treated as invalid,
/// and releasing the handle forwards the pointer to NativeMethods.OrtReleaseObject.
/// </summary>
internal class NativeOnnxObjectHandle : SafeHandle
{
    /// <summary>
    /// Wraps <paramref name="ptr"/>; the base class is told that this instance owns the handle.
    /// </summary>
    /// <param name="ptr">Native pointer to wrap; may be IntPtr.Zero when it is assigned later.</param>
    public NativeOnnxObjectHandle(IntPtr ptr)
        : base(IntPtr.Zero, true)
    {
        SetHandle(ptr);
    }

    /// <summary>
    /// True while no native pointer is held.
    /// </summary>
    public override bool IsInvalid
    {
        get
        {
            return IntPtr.Zero == handle;
        }
    }

    /// <summary>
    /// Hands the stored pointer to the native release function and always reports success.
    /// </summary>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseObject(handle);
        return true;
    }
}
}

View file

@ -70,7 +70,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (typeAndShape != IntPtr.Zero)
{
NativeMethods.OrtReleaseObject(typeAndShape);
NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
}
}
}

View file

@ -23,7 +23,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
public class SessionOptions:IDisposable
{
protected SafeHandle _nativeOption;
public IntPtr _nativePtr;
protected static readonly Lazy<SessionOptions> _default = new Lazy<SessionOptions>(MakeSessionOptionWithMklDnnProvider);
private static string[] cudaDelayLoadedLibs = { "cublas64_100.dll", "cudnn64_7.dll" };
@ -32,7 +32,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
public SessionOptions()
{
_nativeOption = new NativeOnnxObjectHandle(NativeMethods.OrtCreateSessionOptions());
_nativePtr = NativeMethods.OrtCreateSessionOptions();
}
/// <summary>
@ -46,33 +46,11 @@ namespace Microsoft.ML.OnnxRuntime
}
}
/// <summary>
/// Append an execution propvider. When any operator is evaluated, it is executed on the first execution provider that provides it
/// </summary>
/// <param name="provider"></param>
public void AppendExecutionProvider(ExecutionProvider provider)
{
switch (provider)
{
case ExecutionProvider.Cpu:
AppendExecutionProvider(CpuExecutionProviderFactory.Default);
break;
case ExecutionProvider.MklDnn:
AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
break;
case ExecutionProvider.Cuda:
AppendExecutionProvider(CudaExecutionProviderFactory.Default);
break;
default:
break;
}
}
private static SessionOptions MakeSessionOptionWithMklDnnProvider()
{
SessionOptions options = new SessionOptions();
options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
options.AppendExecutionProvider(CpuExecutionProviderFactory.Default);
// NativeMethods.OrtSessionOptionsAppendExecutionProvider_Mkldnn(_nativePtr, 1);
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1);
return options;
}
@ -94,38 +72,12 @@ namespace Microsoft.ML.OnnxRuntime
{
CheckCudaExecutionProviderDLLs();
SessionOptions options = new SessionOptions();
if (deviceId == 0) //default value
options.AppendExecutionProvider(CudaExecutionProviderFactory.Default);
else
options.AppendExecutionProvider(new CudaExecutionProviderFactory(deviceId));
options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
options.AppendExecutionProvider(CpuExecutionProviderFactory.Default);
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options._nativePtr, deviceId);
NativeMethods.OrtSessionOptionsAppendExecutionProvider_Mkldnn(options._nativePtr, 1);
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1);
return options;
}
internal IntPtr NativeHandle
{
get
{
return _nativeOption.DangerousGetHandle(); //Note: this is unsafe, and not ref counted, use with caution
}
}
private void AppendExecutionProvider(NativeOnnxObjectHandle providerFactory)
{
unsafe
{
bool success = false;
providerFactory.DangerousAddRef(ref success);
if (success)
{
NativeMethods.OrtSessionOptionsAppendExecutionProvider(_nativeOption.DangerousGetHandle(), providerFactory.DangerousGetHandle());
providerFactory.DangerousRelease();
}
}
}
// Declared, but called only if OS = Windows.
[DllImport("kernel32.dll")]
private static extern IntPtr LoadLibrary(string dllToLoad);
@ -172,7 +124,7 @@ namespace Microsoft.ML.OnnxRuntime
{
// cleanup managed resources
}
_nativeOption.Dispose();
NativeMethods.OrtReleaseSessionOptions(_nativePtr);
}
#endregion

View file

@ -522,9 +522,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
"OrtSessionGetOutputTypeInfo","OrtReleaseSession","OrtCreateSessionOptions","OrtCloneSessionOptions",
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider",
"OrtCreateCpuExecutionProviderFactory","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
"OrtCreateDefaultAllocator","OrtReleaseObject","OrtAllocatorFree","OrtAllocatorGetInfo",
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider_CPU",
"OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
"OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo",
"OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo",
"OrtCastTypeInfoToTensorInfo","OrtGetTensorShapeAndType","OrtGetTensorElementType","OrtGetNumOfDimensions",
"OrtGetDimensions","OrtGetTensorShapeElementCount","OrtReleaseValue"};

View file

@ -1,47 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/session/onnxruntime_c_api.h"
#include <atomic>
namespace onnxruntime {
/**
* Even it's designed to be inherited, this class doesn't have a virtual destructor.
* No vtable is allowed in this class and its subclasses.
* \tparam T subclass type name
*/
template <typename T>
class ObjectBase {
 private:
  // One AddRef/Release function table per instantiation, shared by every T instance.
  static OrtObject static_cls;

 protected:
  // Pointer to the function table above; set once in the constructor.
  const OrtObject* const ORT_ATTRIBUTE_UNUSED cls_;
  // Number of outstanding references; starts at 1 for the creator.
  std::atomic_int ref_count;

  ObjectBase() : cls_(&static_cls), ref_count(1) {
  }

  // Drops one reference and deletes the object when the count reaches zero.
  // Returns the reference count after the decrement (0 means the object was destroyed),
  // as the OrtObject::Release contract states.
  static uint32_t ORT_API_CALL OrtReleaseImpl(void* this_) {
    T* this_ptr = reinterpret_cast<T*>(this_);
    // Capture the new count before a possible delete; this_ptr must not be read afterwards.
    const int remaining = --this_ptr->ref_count;
    if (remaining == 0)
      delete this_ptr;
    return static_cast<uint32_t>(remaining);
  }

  // Adds one reference and returns the reference count after the increment,
  // as the OrtObject::AddRef contract states.
  static uint32_t ORT_API_CALL OrtAddRefImpl(void* this_) {
    T* this_ptr = reinterpret_cast<T*>(this_);
    return static_cast<uint32_t>(++this_ptr->ref_count);
  }
};
template <typename T>
OrtObject ObjectBase<T>::static_cls = {ObjectBase<T>::OrtAddRefImpl, ObjectBase<T>::OrtReleaseImpl};
} // namespace onnxruntime
#define ORT_CHECK_C_OBJECT_LAYOUT \
{ assert((char*)&ref_count == (char*)this + sizeof(this)); }

View file

@ -7,12 +7,11 @@
#include <string>
#include <atomic>
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/onnx_object_cxx.h"
/**
* Configuration information for a single Run.
*/
struct OrtRunOptions : public onnxruntime::ObjectBase<OrtRunOptions> {
struct OrtRunOptions {
unsigned run_log_verbosity_level = 0; ///< applies to a particular Run() invocation
std::string run_tag; ///< to identify logs generated by a particular Run() invocation

View file

@ -9,9 +9,8 @@ extern "C" {
/**
* \param use_arena zero: false. non-zero: true.
* \param out Call OrtReleaseObject() method when you no longer need to use it.
*/
ORT_API_STATUS(OrtCreateCpuExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out)
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
ORT_ALL_ARGS_NONNULL;
ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out)

View file

@ -6,11 +6,11 @@
#ifdef __cplusplus
extern "C" {
#endif
/**
* \param device_id cuda device id, starts from zero.
* \param out Call OrtReleaseObject() method when you no longer need to use it.
*/
ORT_API_STATUS(OrtCreateCUDAExecutionProviderFactory, int device_id, _Out_ OrtProviderFactoryInterface*** out);
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id);
#ifdef __cplusplus
}

View file

@ -9,9 +9,8 @@ extern "C" {
/**
* \param use_arena zero: false. non-zero: true.
* \param out Call OrtReleaseObject() method when you no longer need to use it.
*/
ORT_API_STATUS(OrtCreateMkldnnExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out);
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena);
#ifdef __cplusplus
}

View file

@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime {
class IExecutionProvider;
// Abstract factory interface: implementations return a newly created execution
// provider whose ownership passes to the caller through the unique_ptr.
struct IExecutionProviderFactory {
  // Virtual so that a factory can be destroyed through a pointer to this interface.
  virtual ~IExecutionProviderFactory() = default;

  // Builds and returns an execution provider.
  virtual std::unique_ptr<IExecutionProvider> CreateProvider() = 0;
};
} // namespace onnxruntime

View file

@ -71,7 +71,7 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, //
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
@ -139,26 +139,10 @@ ORT_RUNTIME_CLASS(AllocatorInfo);
ORT_RUNTIME_CLASS(Session);
ORT_RUNTIME_CLASS(Value);
ORT_RUNTIME_CLASS(ValueList);
struct OrtTypeInfo;
typedef struct OrtTypeInfo OrtTypeInfo;
struct OrtTensorTypeAndShapeInfo;
typedef struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo;
struct OrtRunOptions;
typedef struct OrtRunOptions OrtRunOptions;
struct OrtSessionOptions;
typedef struct OrtSessionOptions OrtSessionOptions;
/**
* Every type inherented from OrtObject should be deleted by OrtReleaseObject(...).
*/
typedef struct OrtObject {
// Returns the new reference count.
uint32_t(ORT_API_CALL* AddRef)(void* this_);
// Returns the new reference count.
uint32_t(ORT_API_CALL* Release)(void* this_);
} OrtObject;
ORT_RUNTIME_CLASS(RunOptions);
ORT_RUNTIME_CLASS(TypeInfo);
ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo);
ORT_RUNTIME_CLASS(SessionOptions);
// When passing in an allocator to any ORT function, be sure that the allocator object
// is not destroyed until the last allocated object using it is freed.
@ -168,12 +152,6 @@ typedef struct OrtAllocator {
const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_);
} OrtAllocator;
// Inherented from OrtObject
typedef struct OrtProviderFactoryInterface {
OrtObject parent;
OrtStatus*(ORT_API_CALL* CreateProvider)(void* this_, OrtProvider** out);
} OrtProviderFactoryInterface;
typedef void(ORT_API_CALL* OrtLoggingFunction)(
void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location,
const char* message);
@ -187,7 +165,7 @@ ORT_ALL_ARGS_NONNULL;
/**
* OrtEnv is process-wide. For each process, only one OrtEnv can be created. Don't do it multiple times
* \param out Should be freed by `OrtReleaseObject` after use
* \param out Should be freed by `OrtReleaseEnv` after use
*/
ORT_API_STATUS(OrtInitializeWithCustomLogger, OrtLoggingFunction logging_function,
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
@ -209,7 +187,7 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
_In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
/**
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseObject after use
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
*/
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions);
@ -245,11 +223,15 @@ ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, u
ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
/**
* The order of invocation indicates the preference order as well. In other words call this method
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order they are called indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* Calling this API is optional in which case Ort will use its internal CPU execution provider.
* If none are called Ort will use its internal CPU execution provider.
*/
ORT_API(void, OrtSessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ OrtProviderFactoryInterface** f);
ORT_API(void, OrtAppendCustomOpLibPath, _In_ OrtSessionOptions* options, const char* lib_path);
@ -257,12 +239,12 @@ ORT_API_STATUS(OrtSessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_
ORT_API_STATUS(OrtSessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out);
/**
* \param out should be freed by OrtReleaseObject after use
* \param out should be freed by OrtReleaseTypeInfo after use
*/
ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
/**
* \param out should be freed by OrtReleaseObject after use
* \param out should be freed by OrtReleaseTypeInfo after use
*/
ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
@ -275,7 +257,7 @@ ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t inde
_Inout_ OrtAllocator* allocator, _Out_ char** value);
/**
* \return A pointer to the newly created object. The pointer should be freed by OrtReleaseObject after use
* \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use
*/
ORT_API(OrtRunOptions*, OrtCreateRunOptions);
@ -345,7 +327,7 @@ ORT_API_STATUS(OrtTensorProtoToOrtValue, _Inout_ OrtAllocator* allocator,
ORT_API(const OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*);
/**
* The retured value should be released by calling OrtReleaseObject
* The returned value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/
ORT_API(OrtTensorTypeAndShapeInfo*, OrtCreateTensorTypeAndShapeInfo);
@ -374,36 +356,19 @@ ORT_API(void, OrtGetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out
ORT_API(int64_t, OrtGetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info);
/**
* \param out Should be freed by OrtReleaseObject after use
* \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use
*/
ORT_API_STATUS(OrtGetTensorShapeAndType, _In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
/**
* Get the type information of an OrtValue
* \param value
* \param out The returned value should be freed by OrtReleaseObject after use
* \param out The returned value should be freed by OrtReleaseTypeInfo after use
*/
ORT_API_STATUS(OrtGetTypeInfo, _In_ const OrtValue* value, OrtTypeInfo** out);
ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value);
/**
* This function is a wrapper to "(*(OrtObject**)ptr)->AddRef(ptr)"
* WARNING: There is NO type checking in this function.
* Before calling this function, caller should make sure current ref count > 0
* \return the new reference count
*/
ORT_API(uint32_t, OrtAddRefToObject, _In_ void* ptr);
/**
*
* A wrapper to "(*(OrtObject**)ptr)->Release(ptr)"
* WARNING: There is NO type checking in this function.
* \param ptr Can be NULL. If it's NULL, this function will return zero.
* \return the new reference count.
*/
ORT_API(uint32_t, OrtReleaseObject, _Inout_opt_ void* ptr);
typedef enum OrtAllocatorType {
OrtDeviceAllocator = 0,
OrtArenaAllocator = 1

View file

@ -38,36 +38,48 @@ struct default_delete<OrtEnv> {
OrtReleaseEnv(ptr);
}
};
} // namespace std
#define DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TYPE_NAME) \
namespace std { \
template <> \
struct default_delete<Ort##TYPE_NAME> { \
void operator()(Ort##TYPE_NAME* ptr) { \
(*reinterpret_cast<OrtObject**>(ptr))->Release(ptr); \
} \
}; \
template <>
struct default_delete<OrtRunOptions> {
void operator()(OrtRunOptions* ptr) {
OrtReleaseRunOptions(ptr);
}
};
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TypeInfo);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TensorTypeAndShapeInfo);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(RunOptions);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(SessionOptions);
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(ProviderFactoryInterface*);
template <>
struct default_delete<OrtTypeInfo> {
void operator()(OrtTypeInfo* ptr) {
OrtReleaseTypeInfo(ptr);
}
};
#undef DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT
template <>
struct default_delete<OrtTensorTypeAndShapeInfo> {
void operator()(OrtTensorTypeAndShapeInfo* ptr) {
OrtReleaseTensorTypeAndShapeInfo(ptr);
}
};
template <>
struct default_delete<OrtSessionOptions> {
void operator()(OrtSessionOptions* ptr) {
OrtReleaseSessionOptions(ptr);
}
};
} // namespace std
namespace onnxruntime {
class SessionOptionsWrapper {
private:
std::unique_ptr<OrtSessionOptions, decltype(&OrtReleaseObject)> value;
std::unique_ptr<OrtSessionOptions> value;
OrtEnv* env_;
SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p, OrtReleaseObject), env_(env){};
SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p), env_(env){};
public:
operator OrtSessionOptions*() { return value.get(); }
//TODO: for the input arg, should we call addref here?
SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions(), OrtReleaseObject), env_(env){};
SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions()), env_(env){};
ORT_REDIRECT_SIMPLE_FUNCTION_CALL(EnableSequentialExecution)
ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableSequentialExecution)
ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableProfiling)
@ -89,15 +101,6 @@ class SessionOptionsWrapper {
OrtSetSessionThreadPoolSize(value.get(), session_thread_pool_size);
}
/**
* The order of invocation indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* Calling this API is optional in which case onnxruntime will use its internal CPU execution provider.
*/
void AppendExecutionProvider(_In_ OrtProviderFactoryInterface** f) {
OrtSessionOptionsAppendExecutionProvider(value.get(), f);
}
SessionOptionsWrapper clone() const {
OrtSessionOptions* p = OrtCloneSessionOptions(value.get());
return SessionOptionsWrapper(env_, p);

View file

@ -1,21 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/session/onnxruntime_c_api.h"
#include <atomic>
// Reads an OrtObject* from the start of the object that ptr points to and calls
// its AddRef entry with ptr, returning whatever that entry returns.
// ptr is dereferenced without a null check, so it must not be null.
ORT_API(uint32_t, OrtAddRefToObject, void* ptr) {
return (*static_cast<OrtObject**>(ptr))->AddRef(ptr);
}
// Reads an OrtObject* from the start of the object that ptr points to and calls
// its Release entry with ptr, returning whatever that entry returns.
// A null ptr is accepted: nothing is called and 0 is returned.
ORT_API(uint32_t, OrtReleaseObject, void* ptr) {
if (ptr == nullptr) return 0;
return (*static_cast<OrtObject**>(ptr))->Release(ptr);
}
namespace {
struct ObjectImpl {
const OrtObject* const cls;
std::atomic_int ref_count;
};
} // namespace

View file

@ -3,8 +3,8 @@
//this file contains implementations of the C API
#include "onnxruntime_typeinfo.h"
#include <cassert>
#include "onnxruntime_typeinfo.h"
#include "core/framework/tensor.h"
#include "core/graph/onnx_protobuf.h"
@ -14,16 +14,19 @@ using onnxruntime::MLFloat16;
using onnxruntime::Tensor;
using onnxruntime::TensorShape;
OrtTypeInfo::OrtTypeInfo(ONNXType type1, void* data1) noexcept : type(type1), data(data1) {
OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) {
}
OrtTypeInfo::~OrtTypeInfo() {
assert(ref_count == 0);
OrtReleaseObject(data);
OrtReleaseTensorTypeAndShapeInfo(data);
}
ORT_API(const struct OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ struct OrtTypeInfo* input) {
return input->type == ONNX_TYPE_TENSOR ? reinterpret_cast<const struct OrtTensorTypeAndShapeInfo*>(input->data) : nullptr;
return input->type == ONNX_TYPE_TENSOR ? input->data : nullptr;
}
ORT_API(void, OrtReleaseTypeInfo, OrtTypeInfo* ptr) {
delete ptr;
}
OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out);

View file

@ -2,8 +2,8 @@
// Licensed under the MIT License.
#pragma once
#include "core/framework/onnx_object_cxx.h"
#include <atomic>
#include "core/session/onnxruntime_c_api.h"
namespace onnxruntime {
class DataTypeImpl;
@ -18,21 +18,21 @@ class TypeProto;
* the equivalent of onnx::TypeProto
* This class is mainly for the C API
*/
struct OrtTypeInfo : public onnxruntime::ObjectBase<OrtTypeInfo> {
struct OrtTypeInfo {
public:
friend class onnxruntime::ObjectBase<OrtTypeInfo>;
ONNXType type = ONNX_TYPE_UNKNOWN;
~OrtTypeInfo();
//owned by this
void* data = nullptr;
OrtTensorTypeAndShapeInfo* data = nullptr;
OrtTypeInfo(const OrtTypeInfo& other) = delete;
OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete;
static OrtStatus* FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const onnxruntime::TensorShape* shape,
const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out);
const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out);
static OrtStatus* FromDataTypeImpl(const onnx::TypeProto*, OrtTypeInfo** out);
private:
OrtTypeInfo(ONNXType type, void* data) noexcept;
~OrtTypeInfo();
OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept;
};

View file

@ -15,36 +15,29 @@ using onnxruntime::DataTypeImpl;
using onnxruntime::MLFloat16;
using onnxruntime::Tensor;
struct OrtTensorTypeAndShapeInfo : public onnxruntime::ObjectBase<OrtTensorTypeAndShapeInfo> {
struct OrtTensorTypeAndShapeInfo {
public:
friend class onnxruntime::ObjectBase<OrtTensorTypeAndShapeInfo>;
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
onnxruntime::TensorShape shape;
static OrtTensorTypeAndShapeInfo* Create() {
return new OrtTensorTypeAndShapeInfo();
}
OrtTensorTypeAndShapeInfo() = default;
OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete;
OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete;
private:
OrtTensorTypeAndShapeInfo() = default;
~OrtTensorTypeAndShapeInfo() {
assert(ref_count == 0);
}
};
#define API_IMPL_BEGIN try {
#define API_IMPL_END \
} \
catch (std::exception & ex) { \
#define API_IMPL_END \
} \
catch (std::exception & ex) { \
return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
}
ORT_API(OrtTensorTypeAndShapeInfo*, OrtCreateTensorTypeAndShapeInfo) {
return OrtTensorTypeAndShapeInfo::Create();
return new OrtTensorTypeAndShapeInfo();
}
ORT_API(void, OrtReleaseTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo* ptr) {
delete ptr;
}
ORT_API_STATUS_IMPL(OrtSetTensorElementType, _In_ OrtTensorTypeAndShapeInfo* this_ptr, enum ONNXTensorElementDataType type) {
@ -126,13 +119,13 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape, const on
OrtTensorTypeAndShapeInfo* ret = OrtCreateTensorTypeAndShapeInfo();
auto status = OrtSetTensorElementType(ret, type);
if (status != nullptr) {
OrtReleaseObject(ret);
OrtReleaseTensorTypeAndShapeInfo(ret);
return status;
}
if (shape != nullptr) {
status = OrtSetDims(ret, shape->GetDims().data(), shape->GetDims().size());
if (status != nullptr) {
OrtReleaseObject(ret);
OrtReleaseTensorTypeAndShapeInfo(ret);
return status;
}
}
@ -160,7 +153,7 @@ ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value) {
return ONNX_TYPE_UNKNOWN;
}
ONNXType ret = out->type;
OrtReleaseObject(out);
OrtReleaseTypeInfo(out);
return ret;
} catch (std::exception&) {
return ONNX_TYPE_UNKNOWN;
@ -170,7 +163,7 @@ ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value) {
/**
* Get the type information of an OrtValue
* \param value
* \return The returned value should be freed by OrtReleaseObject after use
* \return The returned value should be freed by OrtReleaseTypeInfo after use
*/
ORT_API_STATUS_IMPL(OrtGetTypeInfo, _In_ const OrtValue* value, struct OrtTypeInfo** out) {
auto v = reinterpret_cast<const ::onnxruntime::MLValue*>(value);

View file

@ -4,52 +4,33 @@
#include "core/providers/cpu/cpu_provider_factory.h"
#include <atomic>
#include "cpu_execution_provider.h"
#include "core/session/abi_session_options_impl.h"
using namespace onnxruntime;
namespace onnxruntime {
namespace {
struct CpuProviderFactory {
const OrtProviderFactoryInterface* const cls;
std::atomic_int ref_count;
bool create_arena;
CpuProviderFactory();
struct CpuProviderFactory : IExecutionProviderFactory {
CpuProviderFactory(bool create_arena) : create_arena_(create_arena) {}
~CpuProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
private:
bool create_arena_;
};
OrtStatus* ORT_API_CALL CreateCpu(void* this_, OrtProvider** out) {
std::unique_ptr<IExecutionProvider> CpuProviderFactory::CreateProvider() {
CPUExecutionProviderInfo info;
CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_;
info.create_arena = this_ptr->create_arena;
CPUExecutionProvider* ret = new CPUExecutionProvider(info);
*out = (OrtProvider*)ret;
return nullptr;
info.create_arena = create_arena_;
return std::make_unique<CPUExecutionProvider>(info);
}
uint32_t ORT_API_CALL ReleaseCpu(void* this_) {
CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_;
if (--this_ptr->ref_count == 0)
delete this_ptr;
return 0;
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena) {
return std::make_shared<onnxruntime::CpuProviderFactory>(use_arena != 0);
}
uint32_t ORT_API_CALL AddRefCpu(void* this_) {
CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_;
++this_ptr->ref_count;
return 0;
}
} // namespace onnxruntime
constexpr OrtProviderFactoryInterface cpu_cls = {
{AddRefCpu,
ReleaseCpu},
CreateCpu,
};
CpuProviderFactory::CpuProviderFactory() : cls(&cpu_cls), ref_count(1), create_arena(true) {}
} // namespace
ORT_API_STATUS_IMPL(OrtCreateCpuExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) {
CpuProviderFactory* ret = new CpuProviderFactory();
ret->create_arena = (use_arena != 0);
*out = (OrtProviderFactoryInterface**)ret;
// C API entry point: appends a CPU execution provider factory to the
// session options' provider_factories list.
// use_arena is passed through unchanged to CreateExecutionProviderFactory_CPU.
// Always returns nullptr (the success status).
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) {
  auto cpu_factory = onnxruntime::CreateExecutionProviderFactory_CPU(use_arena);
  options->provider_factories.push_back(std::move(cpu_factory));
  return nullptr;
}

View file

@ -1,4 +1,3 @@
OrtAddRefToObject
OrtAllocatorAlloc
OrtAllocatorFree
OrtAllocatorGetInfo
@ -12,7 +11,6 @@ OrtCloneSessionOptions
OrtCompareAllocatorInfo
OrtCreateAllocatorInfo
OrtCreateCpuAllocatorInfo
OrtCreateCpuExecutionProviderFactory
OrtCreateDefaultAllocator
OrtCreateRunOptions
OrtCreateSession
@ -47,9 +45,12 @@ OrtIsTensor
OrtReleaseAllocator
OrtReleaseAllocatorInfo
OrtReleaseEnv
OrtReleaseObject
OrtReleaseRunOptions
OrtReleaseSession
OrtReleaseSessionOptions
OrtReleaseStatus
OrtReleaseTensorTypeAndShapeInfo
OrtReleaseTypeInfo
OrtReleaseValue
OrtRun
OrtRunOptionsGetRunLogVerbosityLevel
@ -63,7 +64,7 @@ OrtSessionGetInputTypeInfo
OrtSessionGetOutputCount
OrtSessionGetOutputName
OrtSessionGetOutputTypeInfo
OrtSessionOptionsAppendExecutionProvider
OrtSessionOptionsAppendExecutionProvider_CPU
OrtSetDims
OrtSetSessionLogId
OrtSetSessionLogVerbosityLevel

View file

@ -4,51 +4,35 @@
#include "core/providers/cuda/cuda_provider_factory.h"
#include <atomic>
#include "cuda_execution_provider.h"
#include "core/session/abi_session_options_impl.h"
using namespace onnxruntime;
namespace {
struct CUDAProviderFactory {
const OrtProviderFactoryInterface* const cls;
std::atomic_int ref_count;
int device_id;
CUDAProviderFactory();
namespace onnxruntime {
struct CUDAProviderFactory : IExecutionProviderFactory {
CUDAProviderFactory(int device_id) : device_id_(device_id) {}
~CUDAProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
private:
int device_id_;
};
OrtStatus* ORT_API_CALL CreateCuda(void* this_, OrtProvider** out) {
std::unique_ptr<IExecutionProvider> CUDAProviderFactory::CreateProvider() {
CUDAExecutionProviderInfo info;
CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_;
info.device_id = this_ptr->device_id;
CUDAExecutionProvider* ret = new CUDAExecutionProvider(info);
*out = (OrtProvider*)ret;
return nullptr;
}
uint32_t ORT_API_CALL ReleaseCuda(void* this_) {
CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_;
if (--this_ptr->ref_count == 0)
delete this_ptr;
return 0;
}
uint32_t ORT_API_CALL AddRefCuda(void* this_) {
CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_;
++this_ptr->ref_count;
return 0;
}
constexpr OrtProviderFactoryInterface cuda_cls = {
AddRefCuda,
ReleaseCuda,
CreateCuda,
};
CUDAProviderFactory::CUDAProviderFactory() : cls(&cuda_cls), ref_count(1), device_id(0) {}
} // namespace
ORT_API_STATUS_IMPL(OrtCreateCUDAExecutionProviderFactory, int device_id, _Out_ OrtProviderFactoryInterface*** out) {
CUDAProviderFactory* ret = new CUDAProviderFactory();
ret->device_id = device_id;
*out = (OrtProviderFactoryInterface**)ret;
info.device_id = device_id_;
return std::make_unique<CUDAExecutionProvider>(info);
}
// Builds a shared factory for CUDA execution providers.
// device_id is stored by CUDAProviderFactory and later copied into the
// CUDAExecutionProviderInfo used by CreateProvider().
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(int device_id) {
  auto cuda_factory = std::make_shared<onnxruntime::CUDAProviderFactory>(device_id);
  return cuda_factory;
}
} // namespace onnxruntime
// C API entry point: appends a CUDA execution provider factory for the given
// device_id to the session options' provider_factories list.
// Always returns nullptr (the success status).
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id) {
  auto cuda_factory = onnxruntime::CreateExecutionProviderFactory_CUDA(device_id);
  options->provider_factories.push_back(std::move(cuda_factory));
  return nullptr;
}

View file

@ -1 +1 @@
OrtCreateCUDAExecutionProviderFactory
OrtSessionOptionsAppendExecutionProvider_CUDA

View file

@ -4,51 +4,34 @@
#include "core/providers/mkldnn/mkldnn_provider_factory.h"
#include <atomic>
#include "mkldnn_execution_provider.h"
#include "core/session/abi_session_options_impl.h"
using namespace onnxruntime;
namespace {
struct MkldnnProviderFactory {
const OrtProviderFactoryInterface* const cls;
std::atomic_int ref_count;
bool create_arena;
MkldnnProviderFactory();
namespace onnxruntime {
struct MkldnnProviderFactory : IExecutionProviderFactory {
MkldnnProviderFactory(bool create_arena) : create_arena_(create_arena) {}
~MkldnnProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
private:
bool create_arena_;
};
OrtStatus* ORT_API_CALL CreateMkldnn(void* this_, OrtProvider** out) {
std::unique_ptr<IExecutionProvider> MkldnnProviderFactory::CreateProvider() {
MKLDNNExecutionProviderInfo info;
MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_;
info.create_arena = this_ptr->create_arena;
MKLDNNExecutionProvider* ret = new MKLDNNExecutionProvider(info);
*out = (OrtProvider*)ret;
return nullptr;
}
uint32_t ORT_API_CALL ReleaseMkldnn(void* this_) {
MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_;
if (--this_ptr->ref_count == 0)
delete this_ptr;
return 0;
}
uint32_t ORT_API_CALL AddRefMkldnn(void* this_) {
MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_;
++this_ptr->ref_count;
return 0;
}
constexpr OrtProviderFactoryInterface mkl_cls = {
{AddRefMkldnn,
ReleaseMkldnn},
CreateMkldnn,
};
MkldnnProviderFactory::MkldnnProviderFactory() : cls(&mkl_cls), ref_count(1), create_arena(true) {}
} // namespace
ORT_API_STATUS_IMPL(OrtCreateMkldnnExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) {
MkldnnProviderFactory* ret = new MkldnnProviderFactory();
ret->create_arena = (use_arena != 0);
*out = (OrtProviderFactoryInterface**)ret;
info.create_arena = create_arena_;
return std::make_unique<MKLDNNExecutionProvider>(info);
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Mkldnn(int device_id) {
return std::make_shared<onnxruntime::MkldnnProviderFactory>(device_id);
}
} // namespace onnxruntime
// C API entry point: appends an MKL-DNN execution provider factory to the
// session options' provider_factories list.
// use_arena is passed through unchanged to CreateExecutionProviderFactory_Mkldnn.
// Always returns nullptr (the success status).
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena) {
  auto mkldnn_factory = onnxruntime::CreateExecutionProviderFactory_Mkldnn(use_arena);
  options->provider_factories.push_back(std::move(mkldnn_factory));
  return nullptr;
}

View file

@ -1 +1 @@
OrtCreateMkldnnExecutionProviderFactory
OrtSessionOptionsAppendExecutionProvider_Mkldnn

View file

@ -8,10 +8,6 @@
#include "abi_session_options_impl.h"
OrtSessionOptions::~OrtSessionOptions() {
assert(ref_count == 0);
for (OrtProviderFactoryInterface** p : provider_factories) {
OrtReleaseObject(p);
}
}
OrtSessionOptions& OrtSessionOptions::operator=(const OrtSessionOptions&) {
@ -19,15 +15,17 @@ OrtSessionOptions& OrtSessionOptions::operator=(const OrtSessionOptions&) {
}
OrtSessionOptions::OrtSessionOptions(const OrtSessionOptions& other)
: value(other.value), custom_op_paths(other.custom_op_paths), provider_factories(other.provider_factories) {
for (OrtProviderFactoryInterface** p : other.provider_factories) {
OrtAddRefToObject(p);
}
}
// Allocates a default-constructed OrtSessionOptions and hands ownership of
// the raw pointer to the caller.
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions) {
  // Construct through unique_ptr, then release ownership to the caller.
  return std::make_unique<OrtSessionOptions>().release();
}
// Frees an OrtSessionOptions by deleting the pointer it is given.
// Deleting a null pointer is a no-op in C++, so passing nullptr is harmless.
ORT_API(void, OrtReleaseSessionOptions, OrtSessionOptions* ptr) {
delete ptr;
}
ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) {
try {
return new OrtSessionOptions(*input);
@ -36,11 +34,6 @@ ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) {
}
}
ORT_API(void, OrtSessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ OrtProviderFactoryInterface** f) {
OrtAddRefToObject(f);
options->provider_factories.push_back(f);
}
// Sets enable_sequential_execution to true on the wrapped SessionOptions.
// `options` is dereferenced without a null check, so it must be non-null.
ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) {
options->value.enable_sequential_execution = true;
}

View file

@ -6,14 +6,14 @@
#include <string>
#include <vector>
#include <atomic>
#include "core/framework/onnx_object_cxx.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/providers/providers.h"
struct OrtSessionOptions : public onnxruntime::ObjectBase<OrtSessionOptions> {
struct OrtSessionOptions {
onnxruntime::SessionOptions value;
std::vector<std::string> custom_op_paths;
std::vector<OrtProviderFactoryInterface**> provider_factories;
std::vector<std::shared_ptr<onnxruntime::IExecutionProviderFactory>> provider_factories;
OrtSessionOptions() = default;
~OrtSessionOptions();
OrtSessionOptions(const OrtSessionOptions& other);

View file

@ -20,7 +20,6 @@
#include "core/framework/environment.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/framework/onnx_object_cxx.h"
#include "core/session/inference_session.h"
#include "abi_session_options_impl.h"
@ -49,7 +48,6 @@ struct OrtEnv {
public:
Environment* value;
LoggingManager* loggingManager;
friend class onnxruntime::ObjectBase<OrtEnv>;
OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) {
}
@ -367,13 +365,10 @@ static OrtStatus* CreateSessionImpl(_In_ OrtEnv* env, _In_ T model_path,
return ToOrtStatus(status);
}
if (options != nullptr)
for (OrtProviderFactoryInterface** p : options->provider_factories) {
OrtProvider* provider;
OrtStatus* error_code = (*p)->CreateProvider(p, &provider);
if (error_code)
return error_code;
sess->RegisterExecutionProvider(std::unique_ptr<onnxruntime::IExecutionProvider>(
reinterpret_cast<onnxruntime::IExecutionProvider*>(provider)));
for (auto& factory : options->provider_factories) {
auto provider = factory->CreateProvider();
if (provider)
sess->RegisterExecutionProvider(std::move(provider));
}
status = sess->Load(model_path);
if (!status.IsOK())
@ -638,5 +633,6 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char)

View file

@ -42,6 +42,7 @@
#define BACKEND_DEVICE BACKEND_PROC BACKEND_MKLDNN BACKEND_MKLML BACKEND_OPENBLAS
#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/providers.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/providers/cpu/cpu_provider_factory.h"
@ -54,6 +55,15 @@
#ifdef USE_NUPHAR
#include "core/providers/nuphar/nuphar_provider_factory.h"
#endif
namespace onnxruntime {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Mkldnn(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(int device_id, const char*);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_BrainSlice(int id, bool f, const char*, const char*, const char*);
} // namespace onnxruntime
#if defined(_MSC_VER)
#pragma warning(disable : 4267 4996 4503 4003)
#endif // _MSC_VER
@ -172,45 +182,32 @@ class SessionObjectInitializer {
}
};
inline void RegisterExecutionProvider(InferenceSession* sess, OrtProviderFactoryInterface** f) {
OrtProvider* p;
(*f)->CreateProvider(f, &p);
std::unique_ptr<onnxruntime::IExecutionProvider> q((onnxruntime::IExecutionProvider*)p);
auto status = sess->RegisterExecutionProvider(std::move(q));
// Asks the factory for a new execution provider and registers it with the
// session. Throws std::runtime_error carrying the status message if the
// session reports a non-OK status.
inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
  auto status = sess->RegisterExecutionProvider(f.CreateProvider());
  if (status.IsOK()) {
    return;
  }
  throw std::runtime_error(status.ErrorMessage().c_str());
}
#define FACTORY_PTR_HOLDER \
std::unique_ptr<OrtProviderFactoryInterface*, decltype(&OrtReleaseObject)> ptr_holder_(f, OrtReleaseObject);
void InitializeSession(InferenceSession* sess) {
onnxruntime::common::Status status;
#ifdef USE_CUDA
{
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
RegisterExecutionProvider(sess, f);
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(0));
}
#endif
#ifdef USE_MKLDNN
{
const bool enable_cpu_mem_arena = true;
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena ? 1 : 0, &f));
RegisterExecutionProvider(sess, f);
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Mkldnn(enable_cpu_mem_arena ? 1 : 0));
}
#endif
#if 0 //USE_NUPHAR
{
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
RegisterExecutionProvider(sess, f);
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nuphar(0, ""));
}
#endif

View file

@ -203,10 +203,7 @@ int real_main(int argc, char* argv[]) {
sf.DisableSequentialExecution();
if (enable_cuda) {
#ifdef USE_CUDA
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
sf.AppendExecutionProvider(f);
OrtReleaseObject(f);
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
#else
fprintf(stderr, "CUDA is not supported in this build");
return -1;
@ -214,10 +211,7 @@ int real_main(int argc, char* argv[]) {
}
if (enable_nuphar) {
#ifdef USE_NUPHAR
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
sf.AppendExecutionProvider(f);
OrtReleaseObject(f);
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, ""));
#else
fprintf(stderr, "Nuphar is not supported in this build");
return -1;
@ -225,10 +219,7 @@ int real_main(int argc, char* argv[]) {
}
if (enable_mkl) {
#ifdef USE_MKLDNN
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena ? 1 : 0, &f));
sf.AppendExecutionProvider(f);
OrtReleaseObject(f);
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, enable_cpu_mem_arena ? 1 : 0));
#else
fprintf(stderr, "MKL-DNN is not supported in this build");
return -1;

View file

@ -12,21 +12,17 @@
#include <filesystem>
#endif
#include "providers.h"
#include "default_providers.h"
using namespace std::experimental::filesystem::v1;
using onnxruntime::Status;
inline void RegisterExecutionProvider(onnxruntime::InferenceSession* sess, OrtProviderFactoryInterface** f) {
OrtProvider* p;
(*f)->CreateProvider(f, &p);
std::unique_ptr<onnxruntime::IExecutionProvider> q((onnxruntime::IExecutionProvider*)p);
auto status = sess->RegisterExecutionProvider(std::move(q));
// Hands ownership of an already-created execution provider to the session.
// Throws std::runtime_error carrying the status message if the session
// reports a non-OK status.
inline void RegisterExecutionProvider(onnxruntime::InferenceSession* sess, std::unique_ptr<onnxruntime::IExecutionProvider>&& f) {
  auto status = sess->RegisterExecutionProvider(std::move(f));
  if (status.IsOK()) {
    return;
  }
  throw std::runtime_error(status.ErrorMessage().c_str());
}
#define FACTORY_PTR_HOLDER \
std::unique_ptr<OrtProviderFactoryInterface*, decltype(&OrtReleaseObject)> ptr_holder_(f, OrtReleaseObject);
Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>& sess, const path& model_url, const std::string& logid) const {
::onnxruntime::SessionOptions so;
@ -41,37 +37,25 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>&
for (const std::string& provider : providers_) {
if (provider == onnxruntime::kCudaExecutionProvider) {
#ifdef USE_CUDA
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess.get(), f);
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultCudaExecutionProvider());
#else
ORT_THROW("CUDA is not supported in this build");
#endif
} else if (provider == onnxruntime::kMklDnnExecutionProvider) {
#ifdef USE_MKLDNN
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena_ ? 1 : 0, &f));
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess.get(), f);
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultMkldnnExecutionProvider(enable_cpu_mem_arena_ ? 1 : 0));
#else
ORT_THROW("CUDA is not supported in this build");
#endif
} else if (provider == onnxruntime::kNupharExecutionProvider) {
#ifdef USE_NUPHAR
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
RegisterExecutionProvider(sess.get(), f);
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultNupharExecutionProvider());
#else
ORT_THROW("CUDA is not supported in this build");
#endif
} else if (provider == onnxruntime::kBrainSliceExecutionProvider) {
#if USE_BRAINSLICE
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateBrainSliceExecutionProviderFactory(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f));
RegisterExecutionProvider(sess.get(), f);
FACTORY_PTR_HOLDER;
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultBrainsliceExecutionProvider());
#else
ORT_THROW("This executable was not built with BrainSlice");
#endif
@ -85,7 +69,7 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>&
ORT_THROW("TensorRT is not supported in this build");
#endif
}
//TODO: add more
// TODO: add more
}
status = sess->Load(model_url.string());

View file

@ -182,10 +182,7 @@ void verify_input_output_count(OrtSession* session) {
#ifdef USE_CUDA
void enable_cuda(OrtSessionOptions* session_option) {
OrtProviderFactoryInterface** factory;
ORT_ABORT_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &factory));
OrtSessionOptionsAppendExecutionProvider(session_option, factory);
OrtReleaseObject(factory);
ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0));
}
#endif
@ -207,9 +204,9 @@ int main(int argc, char* argv[]) {
ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_option, &session));
verify_input_output_count(session);
int ret = run_inference(session, input_file, output_file);
OrtReleaseObject(session_option);
OrtReleaseSessionOptions(session_option);
OrtReleaseSession(session);
OrtReleaseObject(env);
OrtReleaseEnv(env);
if (ret != 0) {
fprintf(stderr, "fail\n");
}

View file

@ -64,30 +64,21 @@ void TestInference(OrtEnv* env, T model_uri,
if (provider_type == 1) {
#ifdef USE_CUDA
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
sf.AppendExecutionProvider(f);
OrtReleaseObject(f);
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
std::cout << "Running simple inference with cuda provider" << std::endl;
#else
return;
#endif
} else if (provider_type == 2) {
#ifdef USE_MKLDNN
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(1, &f));
sf.AppendExecutionProvider(f);
OrtReleaseObject(f);
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, 1));
std::cout << "Running simple inference with mkldnn provider" << std::endl;
#else
return;
#endif
} else if (provider_type == 3) {
#ifdef USE_NUPHAR
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
sf.AppendExecutionProvider(f);
OrtReleaseObject(f);
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, ""));
std::cout << "Running simple inference with nuphar provider" << std::endl;
#else
return;
@ -196,7 +187,7 @@ TEST_F(CApiTest, create_tensor_with_data) {
const struct OrtTensorTypeAndShapeInfo* tensor_info = OrtCastTypeInfoToTensorInfo(type_info);
ASSERT_NE(tensor_info, nullptr);
ASSERT_EQ(1, OrtGetNumOfDimensions(tensor_info));
OrtReleaseObject(type_info);
OrtReleaseTypeInfo(type_info);
}
int main(int argc, char** argv) {

View file

@ -4,28 +4,25 @@
#include "default_providers.h"
#include "providers.h"
#include "core/session/onnxruntime_cxx_api.h"
#define FACTORY_PTR_HOLDER \
std::unique_ptr<OrtProviderFactoryInterface*> ptr_holder_(f);
#include "core/providers/providers.h"
namespace onnxruntime {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Mkldnn(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(int device_id, const char*);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_BrainSlice(int id, bool f, const char*, const char*, const char*);
namespace test {
std::unique_ptr<IExecutionProvider> DefaultCpuExecutionProvider(bool enable_arena) {
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateCpuExecutionProviderFactory(enable_arena ? 1 : 0, &f));
FACTORY_PTR_HOLDER;
OrtProvider* out;
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
return CreateExecutionProviderFactory_CPU(enable_arena)->CreateProvider();
}
std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider() {
#ifdef USE_CUDA
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
FACTORY_PTR_HOLDER;
OrtProvider* out;
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
return CreateExecutionProviderFactory_CUDA(0)->CreateProvider();
#else
return nullptr;
#endif
@ -33,12 +30,7 @@ std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider() {
std::unique_ptr<IExecutionProvider> DefaultMkldnnExecutionProvider(bool enable_arena) {
#ifdef USE_MKLDNN
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_arena ? 1 : 0, &f));
FACTORY_PTR_HOLDER;
OrtProvider* out;
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
return CreateExecutionProviderFactory_Mkldnn(enable_arena ? 1 : 0)->CreateProvider();
#else
ORT_UNUSED_PARAMETER(enable_arena);
return nullptr;
@ -47,12 +39,7 @@ std::unique_ptr<IExecutionProvider> DefaultMkldnnExecutionProvider(bool enable_a
std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider() {
#ifdef USE_NUPHAR
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
FACTORY_PTR_HOLDER;
OrtProvider* out;
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
return CreateExecutionProviderFactory_Nuphar(0, "")->CreateProvider();
#else
return nullptr;
#endif
@ -60,12 +47,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider() {
std::unique_ptr<IExecutionProvider> DefaultBrainSliceExecutionProvider() {
#ifdef USE_BRAINSLICE
OrtProviderFactoryInterface** f;
ORT_THROW_ON_ERROR(OrtCreateBrainSliceExecutionProviderFactory(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f));
FACTORY_PTR_HOLDER;
OrtProvider* out;
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
return CreateExecutionProviderFactory_BrainSlice(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f));
#else
return nullptr;
#endif