mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
C API - Remove reference counting (#344)
This commit is contained in:
parent
6349114583
commit
d875ab2acd
36 changed files with 263 additions and 661 deletions
|
|
@ -1,106 +0,0 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
|
||||
internal class CpuExecutionProviderFactory: NativeOnnxObjectHandle
|
||||
{
|
||||
protected static readonly Lazy<CpuExecutionProviderFactory> _default = new Lazy<CpuExecutionProviderFactory>(() => new CpuExecutionProviderFactory());
|
||||
|
||||
public CpuExecutionProviderFactory(bool useArena=true)
|
||||
:base(IntPtr.Zero)
|
||||
{
|
||||
int useArenaInt = useArena ? 1 : 0;
|
||||
try
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuExecutionProviderFactory(useArenaInt, out handle));
|
||||
}
|
||||
catch(OnnxRuntimeException e)
|
||||
{
|
||||
if (IsInvalid)
|
||||
{
|
||||
ReleaseHandle();
|
||||
handle = IntPtr.Zero;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public static CpuExecutionProviderFactory Default
|
||||
{
|
||||
get
|
||||
{
|
||||
return _default.Value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class MklDnnExecutionProviderFactory : NativeOnnxObjectHandle
|
||||
{
|
||||
protected static readonly Lazy<MklDnnExecutionProviderFactory> _default = new Lazy<MklDnnExecutionProviderFactory>(() => new MklDnnExecutionProviderFactory());
|
||||
|
||||
public MklDnnExecutionProviderFactory(bool useArena = true)
|
||||
:base(IntPtr.Zero)
|
||||
{
|
||||
int useArenaInt = useArena ? 1 : 0;
|
||||
try
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMkldnnExecutionProviderFactory(useArenaInt, out handle));
|
||||
}
|
||||
catch (OnnxRuntimeException e)
|
||||
{
|
||||
if (IsInvalid)
|
||||
{
|
||||
ReleaseHandle();
|
||||
handle = IntPtr.Zero;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public static MklDnnExecutionProviderFactory Default
|
||||
{
|
||||
get
|
||||
{
|
||||
return _default.Value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class CudaExecutionProviderFactory : NativeOnnxObjectHandle
|
||||
{
|
||||
protected static readonly Lazy<CudaExecutionProviderFactory> _default = new Lazy<CudaExecutionProviderFactory>(() => new CudaExecutionProviderFactory());
|
||||
|
||||
public CudaExecutionProviderFactory(int deviceId = 0)
|
||||
: base(IntPtr.Zero)
|
||||
{
|
||||
try
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCUDAExecutionProviderFactory(deviceId, out handle));
|
||||
}
|
||||
catch (OnnxRuntimeException e)
|
||||
{
|
||||
if (IsInvalid)
|
||||
{
|
||||
ReleaseHandle();
|
||||
handle = IntPtr.Zero;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public static CudaExecutionProviderFactory Default
|
||||
{
|
||||
get
|
||||
{
|
||||
return _default.Value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -45,9 +45,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
try
|
||||
{
|
||||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.NativeHandle, out _nativeHandle));
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options._nativePtr, out _nativeHandle));
|
||||
else
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.NativeHandle, out _nativeHandle));
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options._nativePtr, out _nativeHandle));
|
||||
|
||||
// Initialize input/output metadata
|
||||
_inputMetadata = new Dictionary<string, NodeMetadata>();
|
||||
|
|
@ -275,7 +275,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
if (typeInfo != IntPtr.Zero)
|
||||
{
|
||||
NativeMethods.OrtReleaseObject(typeInfo);
|
||||
NativeMethods.OrtReleaseTypeInfo(typeInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -292,7 +292,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
if (typeInfo != IntPtr.Zero)
|
||||
{
|
||||
NativeMethods.OrtReleaseObject(typeInfo);
|
||||
NativeMethods.OrtReleaseTypeInfo(typeInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
if (typeAndShape != IntPtr.Zero)
|
||||
{
|
||||
NativeMethods.OrtReleaseObject(typeAndShape);
|
||||
NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
protected static void Delete(IntPtr allocator)
|
||||
{
|
||||
NativeMethods.OrtReleaseObject(allocator);
|
||||
NativeMethods.OrtReleaseAllocator(allocator);
|
||||
}
|
||||
|
||||
protected override bool ReleaseHandle()
|
||||
|
|
|
|||
|
|
@ -90,20 +90,22 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IntPtr /*(OrtAllocator*)*/ allocator,
|
||||
out IntPtr /*(char**)*/name);
|
||||
|
||||
// release the typeinfo using OrtReleaseObject
|
||||
// release the typeinfo using OrtReleaseTypeInfo
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/OrtSessionGetInputTypeInfo(
|
||||
IntPtr /*(const OrtSession*)*/ session,
|
||||
ulong index, //TODO: port for size_t
|
||||
out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo);
|
||||
|
||||
// release the typeinfo using OrtReleaseObject
|
||||
// release the typeinfo using OrtReleaseTypeInfo
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/OrtSessionGetOutputTypeInfo(
|
||||
IntPtr /*(const OrtSession*)*/ session,
|
||||
ulong index, //TODO: port for size_t
|
||||
out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtReleaseSession(IntPtr /*(OrtSession*)*/session);
|
||||
|
|
@ -112,11 +114,12 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
#region SessionOptions API
|
||||
|
||||
//Release using OrtReleaseObject
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*OrtSessionOptions* */ OrtCreateSessionOptions();
|
||||
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtSessionOptions*)*/OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions);
|
||||
|
||||
|
|
@ -153,22 +156,20 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern int OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
|
||||
|
||||
|
||||
///**
|
||||
// * The order of invocation indicates the preference order as well. In other words call this method
|
||||
// * on your most preferred execution provider first followed by the less preferred ones.
|
||||
// * Calling this API is optional in which case onnxruntime will use its internal CPU execution provider.
|
||||
// */
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtSessionOptionsAppendExecutionProvider(IntPtr /*(OrtSessionOptions*)*/ options, IntPtr /* (OrtProviderFactoryPtr*)*/ factory);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CPU(IntPtr /*(OrtSessionOptions*) */ options, int use_arena);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCpuExecutionProviderFactory(int use_arena, out IntPtr /*(OrtProviderFactoryPtr*)*/ factory);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Mkldnn(IntPtr /*(OrtSessionOptions*) */ options, int use_arena);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateMkldnnExecutionProviderFactory(int use_arena, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCUDAExecutionProviderFactory(int device_id, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
|
||||
|
||||
//[DllImport(nativeLib, CharSet = charSet)]
|
||||
//public static extern IntPtr /*(OrtStatus*)*/ OrtCreateNupharExecutionProviderFactory(int device_id, string target_str, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory);
|
||||
|
|
@ -220,13 +221,8 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/OrtCreateDefaultAllocator(out IntPtr /*(OrtAllocator**)*/ allocator);
|
||||
|
||||
/// <summary>
|
||||
/// Releases/Unrefs any object, including the Allocator
|
||||
/// </summary>
|
||||
/// <param name="ptr"></param>
|
||||
/// <returns>remaining ref count</returns>
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern uint /*remaining ref count*/ OrtReleaseObject(IntPtr /*(void*)*/ ptr);
|
||||
public static extern void OrtReleaseAllocator(IntPtr /*(OrtAllocator*)*/ allocator);
|
||||
|
||||
/// <summary>
|
||||
/// Release any object allocated by an allocator
|
||||
|
|
@ -265,6 +261,10 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorShapeAndType(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
|
||||
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtReleaseTensorTypeAndShapeInfo(IntPtr /*(OrtTensorTypeAndShapeInfo*)*/ value);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern TensorElementType OrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +0,0 @@
|
|||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
|
||||
internal class NativeOnnxObjectHandle : SafeHandle
|
||||
{
|
||||
public NativeOnnxObjectHandle(IntPtr ptr)
|
||||
: base(IntPtr.Zero, true)
|
||||
{
|
||||
handle = ptr;
|
||||
}
|
||||
public override bool IsInvalid
|
||||
{
|
||||
get
|
||||
{
|
||||
return (handle == IntPtr.Zero);
|
||||
}
|
||||
}
|
||||
|
||||
protected override bool ReleaseHandle()
|
||||
{
|
||||
NativeMethods.OrtReleaseObject(handle);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -70,7 +70,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
if (typeAndShape != IntPtr.Zero)
|
||||
{
|
||||
NativeMethods.OrtReleaseObject(typeAndShape);
|
||||
NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </summary>
|
||||
public class SessionOptions:IDisposable
|
||||
{
|
||||
protected SafeHandle _nativeOption;
|
||||
public IntPtr _nativePtr;
|
||||
protected static readonly Lazy<SessionOptions> _default = new Lazy<SessionOptions>(MakeSessionOptionWithMklDnnProvider);
|
||||
private static string[] cudaDelayLoadedLibs = { "cublas64_100.dll", "cudnn64_7.dll" };
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </summary>
|
||||
public SessionOptions()
|
||||
{
|
||||
_nativeOption = new NativeOnnxObjectHandle(NativeMethods.OrtCreateSessionOptions());
|
||||
_nativePtr = NativeMethods.OrtCreateSessionOptions();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -46,33 +46,11 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Append an execution propvider. When any operator is evaluated, it is executed on the first execution provider that provides it
|
||||
/// </summary>
|
||||
/// <param name="provider"></param>
|
||||
public void AppendExecutionProvider(ExecutionProvider provider)
|
||||
{
|
||||
switch (provider)
|
||||
{
|
||||
case ExecutionProvider.Cpu:
|
||||
AppendExecutionProvider(CpuExecutionProviderFactory.Default);
|
||||
break;
|
||||
case ExecutionProvider.MklDnn:
|
||||
AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
|
||||
break;
|
||||
case ExecutionProvider.Cuda:
|
||||
AppendExecutionProvider(CudaExecutionProviderFactory.Default);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private static SessionOptions MakeSessionOptionWithMklDnnProvider()
|
||||
{
|
||||
SessionOptions options = new SessionOptions();
|
||||
options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
|
||||
options.AppendExecutionProvider(CpuExecutionProviderFactory.Default);
|
||||
// NativeMethods.OrtSessionOptionsAppendExecutionProvider_Mkldnn(_nativePtr, 1);
|
||||
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1);
|
||||
return options;
|
||||
}
|
||||
|
||||
|
|
@ -94,38 +72,12 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
CheckCudaExecutionProviderDLLs();
|
||||
SessionOptions options = new SessionOptions();
|
||||
if (deviceId == 0) //default value
|
||||
options.AppendExecutionProvider(CudaExecutionProviderFactory.Default);
|
||||
else
|
||||
options.AppendExecutionProvider(new CudaExecutionProviderFactory(deviceId));
|
||||
options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default);
|
||||
options.AppendExecutionProvider(CpuExecutionProviderFactory.Default);
|
||||
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options._nativePtr, deviceId);
|
||||
NativeMethods.OrtSessionOptionsAppendExecutionProvider_Mkldnn(options._nativePtr, 1);
|
||||
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1);
|
||||
return options;
|
||||
}
|
||||
|
||||
internal IntPtr NativeHandle
|
||||
{
|
||||
get
|
||||
{
|
||||
return _nativeOption.DangerousGetHandle(); //Note: this is unsafe, and not ref counted, use with caution
|
||||
}
|
||||
}
|
||||
|
||||
private void AppendExecutionProvider(NativeOnnxObjectHandle providerFactory)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
bool success = false;
|
||||
providerFactory.DangerousAddRef(ref success);
|
||||
if (success)
|
||||
{
|
||||
NativeMethods.OrtSessionOptionsAppendExecutionProvider(_nativeOption.DangerousGetHandle(), providerFactory.DangerousGetHandle());
|
||||
providerFactory.DangerousRelease();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Declared, but called only if OS = Windows.
|
||||
[DllImport("kernel32.dll")]
|
||||
private static extern IntPtr LoadLibrary(string dllToLoad);
|
||||
|
|
@ -172,7 +124,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
// cleanup managed resources
|
||||
}
|
||||
_nativeOption.Dispose();
|
||||
NativeMethods.OrtReleaseSessionOptions(_nativePtr);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
|
|
|||
|
|
@ -522,9 +522,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
"OrtSessionGetOutputTypeInfo","OrtReleaseSession","OrtCreateSessionOptions","OrtCloneSessionOptions",
|
||||
"OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling",
|
||||
"OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena",
|
||||
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider",
|
||||
"OrtCreateCpuExecutionProviderFactory","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
|
||||
"OrtCreateDefaultAllocator","OrtReleaseObject","OrtAllocatorFree","OrtAllocatorGetInfo",
|
||||
"OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider_CPU",
|
||||
"OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo",
|
||||
"OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo",
|
||||
"OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo",
|
||||
"OrtCastTypeInfoToTensorInfo","OrtGetTensorShapeAndType","OrtGetTensorElementType","OrtGetNumOfDimensions",
|
||||
"OrtGetDimensions","OrtGetTensorShapeElementCount","OrtReleaseValue"};
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include <atomic>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* Even it's designed to be inherited, this class doesn't have a virtual destructor.
|
||||
* No vtable is allowed in this class and its subclasses.
|
||||
* \tparam T subclass type name
|
||||
*/
|
||||
template <typename T>
|
||||
class ObjectBase {
|
||||
private:
|
||||
static OrtObject static_cls;
|
||||
|
||||
protected:
|
||||
const OrtObject* const ORT_ATTRIBUTE_UNUSED cls_;
|
||||
std::atomic_int ref_count;
|
||||
ObjectBase() : cls_(&static_cls), ref_count(1) {
|
||||
}
|
||||
|
||||
static uint32_t ORT_API_CALL OrtReleaseImpl(void* this_) {
|
||||
T* this_ptr = reinterpret_cast<T*>(this_);
|
||||
if (--this_ptr->ref_count == 0)
|
||||
delete this_ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static uint32_t ORT_API_CALL OrtAddRefImpl(void* this_) {
|
||||
T* this_ptr = reinterpret_cast<T*>(this_);
|
||||
++this_ptr->ref_count;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
OrtObject ObjectBase<T>::static_cls = {ObjectBase<T>::OrtAddRefImpl, ObjectBase<T>::OrtReleaseImpl};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#define ORT_CHECK_C_OBJECT_LAYOUT \
|
||||
{ assert((char*)&ref_count == (char*)this + sizeof(this)); }
|
||||
|
|
@ -7,12 +7,11 @@
|
|||
#include <string>
|
||||
#include <atomic>
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/framework/onnx_object_cxx.h"
|
||||
|
||||
/**
|
||||
* Configuration information for a single Run.
|
||||
*/
|
||||
struct OrtRunOptions : public onnxruntime::ObjectBase<OrtRunOptions> {
|
||||
struct OrtRunOptions {
|
||||
unsigned run_log_verbosity_level = 0; ///< applies to a particular Run() invocation
|
||||
std::string run_tag; ///< to identify logs generated by a particular Run() invocation
|
||||
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@ extern "C" {
|
|||
|
||||
/**
|
||||
* \param use_arena zero: false. non-zero: true.
|
||||
* \param out Call OrtReleaseObject() method when you no longer need to use it.
|
||||
*/
|
||||
ORT_API_STATUS(OrtCreateCpuExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out)
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
|
||||
ORT_ALL_ARGS_NONNULL;
|
||||
|
||||
ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \param device_id cuda device id, starts from zero.
|
||||
* \param out Call OrtReleaseObject() method when you no longer need to use it.
|
||||
*/
|
||||
ORT_API_STATUS(OrtCreateCUDAExecutionProviderFactory, int device_id, _Out_ OrtProviderFactoryInterface*** out);
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@ extern "C" {
|
|||
|
||||
/**
|
||||
* \param use_arena zero: false. non-zero: true.
|
||||
* \param out Call OrtReleaseObject() method when you no longer need to use it.
|
||||
*/
|
||||
ORT_API_STATUS(OrtCreateMkldnnExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out);
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
11
include/onnxruntime/core/providers/providers.h
Normal file
11
include/onnxruntime/core/providers/providers.h
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace onnxruntime {
|
||||
class IExecutionProvider;
|
||||
|
||||
struct IExecutionProviderFactory {
|
||||
virtual ~IExecutionProviderFactory() {}
|
||||
virtual std::unique_ptr<IExecutionProvider> CreateProvider() = 0;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -71,7 +71,7 @@ typedef enum ONNXTensorElementDataType {
|
|||
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, //
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
|
||||
|
|
@ -139,26 +139,10 @@ ORT_RUNTIME_CLASS(AllocatorInfo);
|
|||
ORT_RUNTIME_CLASS(Session);
|
||||
ORT_RUNTIME_CLASS(Value);
|
||||
ORT_RUNTIME_CLASS(ValueList);
|
||||
|
||||
struct OrtTypeInfo;
|
||||
typedef struct OrtTypeInfo OrtTypeInfo;
|
||||
struct OrtTensorTypeAndShapeInfo;
|
||||
typedef struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo;
|
||||
struct OrtRunOptions;
|
||||
typedef struct OrtRunOptions OrtRunOptions;
|
||||
struct OrtSessionOptions;
|
||||
typedef struct OrtSessionOptions OrtSessionOptions;
|
||||
|
||||
/**
|
||||
* Every type inherented from OrtObject should be deleted by OrtReleaseObject(...).
|
||||
*/
|
||||
typedef struct OrtObject {
|
||||
// Returns the new reference count.
|
||||
uint32_t(ORT_API_CALL* AddRef)(void* this_);
|
||||
// Returns the new reference count.
|
||||
uint32_t(ORT_API_CALL* Release)(void* this_);
|
||||
|
||||
} OrtObject;
|
||||
ORT_RUNTIME_CLASS(RunOptions);
|
||||
ORT_RUNTIME_CLASS(TypeInfo);
|
||||
ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo);
|
||||
ORT_RUNTIME_CLASS(SessionOptions);
|
||||
|
||||
// When passing in an allocator to any ORT function, be sure that the allocator object
|
||||
// is not destroyed until the last allocated object using it is freed.
|
||||
|
|
@ -168,12 +152,6 @@ typedef struct OrtAllocator {
|
|||
const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_);
|
||||
} OrtAllocator;
|
||||
|
||||
// Inherented from OrtObject
|
||||
typedef struct OrtProviderFactoryInterface {
|
||||
OrtObject parent;
|
||||
OrtStatus*(ORT_API_CALL* CreateProvider)(void* this_, OrtProvider** out);
|
||||
} OrtProviderFactoryInterface;
|
||||
|
||||
typedef void(ORT_API_CALL* OrtLoggingFunction)(
|
||||
void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location,
|
||||
const char* message);
|
||||
|
|
@ -187,7 +165,7 @@ ORT_ALL_ARGS_NONNULL;
|
|||
|
||||
/**
|
||||
* OrtEnv is process-wise. For each process, only one OrtEnv can be created. Don't do it multiple times
|
||||
* \param out Should be freed by `OrtReleaseObject` after use
|
||||
* \param out Should be freed by `OrtReleaseEnv` after use
|
||||
*/
|
||||
ORT_API_STATUS(OrtInitializeWithCustomLogger, OrtLoggingFunction logging_function,
|
||||
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
|
||||
|
|
@ -209,7 +187,7 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
|
|||
_In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
|
||||
|
||||
/**
|
||||
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseObject after use
|
||||
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
|
||||
*/
|
||||
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions);
|
||||
|
||||
|
|
@ -245,11 +223,15 @@ ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, u
|
|||
ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
|
||||
|
||||
/**
|
||||
* The order of invocation indicates the preference order as well. In other words call this method
|
||||
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
|
||||
* functions to enable them in the session:
|
||||
* OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
* OrtSessionOptionsAppendExecutionProvider_CUDA
|
||||
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
|
||||
* The order they care called indicates the preference order as well. In other words call this method
|
||||
* on your most preferred execution provider first followed by the less preferred ones.
|
||||
* Calling this API is optional in which case Ort will use its internal CPU execution provider.
|
||||
* If none are called Ort will use its internal CPU execution provider.
|
||||
*/
|
||||
ORT_API(void, OrtSessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ OrtProviderFactoryInterface** f);
|
||||
|
||||
ORT_API(void, OrtAppendCustomOpLibPath, _In_ OrtSessionOptions* options, const char* lib_path);
|
||||
|
||||
|
|
@ -257,12 +239,12 @@ ORT_API_STATUS(OrtSessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_
|
|||
ORT_API_STATUS(OrtSessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out);
|
||||
|
||||
/**
|
||||
* \param out should be freed by OrtReleaseObject after use
|
||||
* \param out should be freed by OrtReleaseTypeInfo after use
|
||||
*/
|
||||
ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
|
||||
|
||||
/**
|
||||
* \param out should be freed by OrtReleaseObject after use
|
||||
* \param out should be freed by OrtReleaseTypeInfo after use
|
||||
*/
|
||||
ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
|
||||
|
||||
|
|
@ -275,7 +257,7 @@ ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t inde
|
|||
_Inout_ OrtAllocator* allocator, _Out_ char** value);
|
||||
|
||||
/**
|
||||
* \return A pointer to the newly created object. The pointer should be freed by OrtReleaseObject after use
|
||||
* \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use
|
||||
*/
|
||||
ORT_API(OrtRunOptions*, OrtCreateRunOptions);
|
||||
|
||||
|
|
@ -345,7 +327,7 @@ ORT_API_STATUS(OrtTensorProtoToOrtValue, _Inout_ OrtAllocator* allocator,
|
|||
ORT_API(const OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*);
|
||||
|
||||
/**
|
||||
* The retured value should be released by calling OrtReleaseObject
|
||||
* The retured value should be released by calling OrtReleaseTensorTypeAndShapeInfo
|
||||
*/
|
||||
ORT_API(OrtTensorTypeAndShapeInfo*, OrtCreateTensorTypeAndShapeInfo);
|
||||
|
||||
|
|
@ -374,36 +356,19 @@ ORT_API(void, OrtGetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out
|
|||
ORT_API(int64_t, OrtGetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info);
|
||||
|
||||
/**
|
||||
* \param out Should be freed by OrtReleaseObject after use
|
||||
* \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use
|
||||
*/
|
||||
ORT_API_STATUS(OrtGetTensorShapeAndType, _In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
|
||||
|
||||
/**
|
||||
* Get the type information of an OrtValue
|
||||
* \param value
|
||||
* \param out The returned value should be freed by OrtReleaseObject after use
|
||||
* \param out The returned value should be freed by OrtReleaseTypeInfo after use
|
||||
*/
|
||||
ORT_API_STATUS(OrtGetTypeInfo, _In_ const OrtValue* value, OrtTypeInfo** out);
|
||||
|
||||
ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value);
|
||||
|
||||
/**
|
||||
* This function is a wrapper to "(*(OrtObject**)ptr)->AddRef(ptr)"
|
||||
* WARNING: There is NO type checking in this function.
|
||||
* Before calling this function, caller should make sure current ref count > 0
|
||||
* \return the new reference count
|
||||
*/
|
||||
ORT_API(uint32_t, OrtAddRefToObject, _In_ void* ptr);
|
||||
|
||||
/**
|
||||
*
|
||||
* A wrapper to "(*(OrtObject**)ptr)->Release(ptr)"
|
||||
* WARNING: There is NO type checking in this function.
|
||||
* \param ptr Can be NULL. If it's NULL, this function will return zero.
|
||||
* \return the new reference count.
|
||||
*/
|
||||
ORT_API(uint32_t, OrtReleaseObject, _Inout_opt_ void* ptr);
|
||||
|
||||
typedef enum OrtAllocatorType {
|
||||
OrtDeviceAllocator = 0,
|
||||
OrtArenaAllocator = 1
|
||||
|
|
|
|||
|
|
@ -38,36 +38,48 @@ struct default_delete<OrtEnv> {
|
|||
OrtReleaseEnv(ptr);
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
#define DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TYPE_NAME) \
|
||||
namespace std { \
|
||||
template <> \
|
||||
struct default_delete<Ort##TYPE_NAME> { \
|
||||
void operator()(Ort##TYPE_NAME* ptr) { \
|
||||
(*reinterpret_cast<OrtObject**>(ptr))->Release(ptr); \
|
||||
} \
|
||||
}; \
|
||||
template <>
|
||||
struct default_delete<OrtRunOptions> {
|
||||
void operator()(OrtRunOptions* ptr) {
|
||||
OrtReleaseRunOptions(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TypeInfo);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TensorTypeAndShapeInfo);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(RunOptions);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(SessionOptions);
|
||||
DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(ProviderFactoryInterface*);
|
||||
template <>
|
||||
struct default_delete<OrtTypeInfo> {
|
||||
void operator()(OrtTypeInfo* ptr) {
|
||||
OrtReleaseTypeInfo(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
#undef DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT
|
||||
template <>
|
||||
struct default_delete<OrtTensorTypeAndShapeInfo> {
|
||||
void operator()(OrtTensorTypeAndShapeInfo* ptr) {
|
||||
OrtReleaseTensorTypeAndShapeInfo(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct default_delete<OrtSessionOptions> {
|
||||
void operator()(OrtSessionOptions* ptr) {
|
||||
OrtReleaseSessionOptions(ptr);
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
namespace onnxruntime {
|
||||
class SessionOptionsWrapper {
|
||||
private:
|
||||
std::unique_ptr<OrtSessionOptions, decltype(&OrtReleaseObject)> value;
|
||||
std::unique_ptr<OrtSessionOptions> value;
|
||||
OrtEnv* env_;
|
||||
SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p, OrtReleaseObject), env_(env){};
|
||||
SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p), env_(env){};
|
||||
|
||||
public:
|
||||
operator OrtSessionOptions*() { return value.get(); }
|
||||
|
||||
//TODO: for the input arg, should we call addref here?
|
||||
SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions(), OrtReleaseObject), env_(env){};
|
||||
SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions()), env_(env){};
|
||||
ORT_REDIRECT_SIMPLE_FUNCTION_CALL(EnableSequentialExecution)
|
||||
ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableSequentialExecution)
|
||||
ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableProfiling)
|
||||
|
|
@ -89,15 +101,6 @@ class SessionOptionsWrapper {
|
|||
OrtSetSessionThreadPoolSize(value.get(), session_thread_pool_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* The order of invocation indicates the preference order as well. In other words call this method
|
||||
* on your most preferred execution provider first followed by the less preferred ones.
|
||||
* Calling this API is optional in which case onnxruntime will use its internal CPU execution provider.
|
||||
*/
|
||||
void AppendExecutionProvider(_In_ OrtProviderFactoryInterface** f) {
|
||||
OrtSessionOptionsAppendExecutionProvider(value.get(), f);
|
||||
}
|
||||
|
||||
SessionOptionsWrapper clone() const {
|
||||
OrtSessionOptions* p = OrtCloneSessionOptions(value.get());
|
||||
return SessionOptionsWrapper(env_, p);
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include <atomic>
|
||||
|
||||
ORT_API(uint32_t, OrtAddRefToObject, void* ptr) {
|
||||
return (*static_cast<OrtObject**>(ptr))->AddRef(ptr);
|
||||
}
|
||||
|
||||
ORT_API(uint32_t, OrtReleaseObject, void* ptr) {
|
||||
if (ptr == nullptr) return 0;
|
||||
return (*static_cast<OrtObject**>(ptr))->Release(ptr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ObjectImpl {
|
||||
const OrtObject* const cls;
|
||||
std::atomic_int ref_count;
|
||||
};
|
||||
} // namespace
|
||||
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
//this file contains implementations of the C API
|
||||
|
||||
#include "onnxruntime_typeinfo.h"
|
||||
#include <cassert>
|
||||
#include "onnxruntime_typeinfo.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
|
|
@ -14,16 +14,19 @@ using onnxruntime::MLFloat16;
|
|||
using onnxruntime::Tensor;
|
||||
using onnxruntime::TensorShape;
|
||||
|
||||
OrtTypeInfo::OrtTypeInfo(ONNXType type1, void* data1) noexcept : type(type1), data(data1) {
|
||||
OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) {
|
||||
}
|
||||
|
||||
OrtTypeInfo::~OrtTypeInfo() {
|
||||
assert(ref_count == 0);
|
||||
OrtReleaseObject(data);
|
||||
OrtReleaseTensorTypeAndShapeInfo(data);
|
||||
}
|
||||
|
||||
ORT_API(const struct OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ struct OrtTypeInfo* input) {
|
||||
return input->type == ONNX_TYPE_TENSOR ? reinterpret_cast<const struct OrtTensorTypeAndShapeInfo*>(input->data) : nullptr;
|
||||
return input->type == ONNX_TYPE_TENSOR ? input->data : nullptr;
|
||||
}
|
||||
|
||||
ORT_API(void, OrtReleaseTypeInfo, OrtTypeInfo* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out);
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/framework/onnx_object_cxx.h"
|
||||
#include <atomic>
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class DataTypeImpl;
|
||||
|
|
@ -18,21 +18,21 @@ class TypeProto;
|
|||
* the equivalent of onnx::TypeProto
|
||||
* This class is mainly for the C API
|
||||
*/
|
||||
struct OrtTypeInfo : public onnxruntime::ObjectBase<OrtTypeInfo> {
|
||||
struct OrtTypeInfo {
|
||||
public:
|
||||
friend class onnxruntime::ObjectBase<OrtTypeInfo>;
|
||||
|
||||
ONNXType type = ONNX_TYPE_UNKNOWN;
|
||||
|
||||
~OrtTypeInfo();
|
||||
|
||||
//owned by this
|
||||
void* data = nullptr;
|
||||
OrtTensorTypeAndShapeInfo* data = nullptr;
|
||||
OrtTypeInfo(const OrtTypeInfo& other) = delete;
|
||||
OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete;
|
||||
|
||||
static OrtStatus* FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const onnxruntime::TensorShape* shape,
|
||||
const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out);
|
||||
const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out);
|
||||
static OrtStatus* FromDataTypeImpl(const onnx::TypeProto*, OrtTypeInfo** out);
|
||||
|
||||
private:
|
||||
OrtTypeInfo(ONNXType type, void* data) noexcept;
|
||||
~OrtTypeInfo();
|
||||
OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -15,36 +15,29 @@ using onnxruntime::DataTypeImpl;
|
|||
using onnxruntime::MLFloat16;
|
||||
using onnxruntime::Tensor;
|
||||
|
||||
struct OrtTensorTypeAndShapeInfo : public onnxruntime::ObjectBase<OrtTensorTypeAndShapeInfo> {
|
||||
struct OrtTensorTypeAndShapeInfo {
|
||||
public:
|
||||
friend class onnxruntime::ObjectBase<OrtTensorTypeAndShapeInfo>;
|
||||
|
||||
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
onnxruntime::TensorShape shape;
|
||||
|
||||
static OrtTensorTypeAndShapeInfo* Create() {
|
||||
return new OrtTensorTypeAndShapeInfo();
|
||||
}
|
||||
|
||||
OrtTensorTypeAndShapeInfo() = default;
|
||||
OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete;
|
||||
OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete;
|
||||
|
||||
private:
|
||||
OrtTensorTypeAndShapeInfo() = default;
|
||||
~OrtTensorTypeAndShapeInfo() {
|
||||
assert(ref_count == 0);
|
||||
}
|
||||
};
|
||||
|
||||
#define API_IMPL_BEGIN try {
|
||||
#define API_IMPL_END \
|
||||
} \
|
||||
catch (std::exception & ex) { \
|
||||
#define API_IMPL_END \
|
||||
} \
|
||||
catch (std::exception & ex) { \
|
||||
return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
|
||||
}
|
||||
|
||||
ORT_API(OrtTensorTypeAndShapeInfo*, OrtCreateTensorTypeAndShapeInfo) {
|
||||
return OrtTensorTypeAndShapeInfo::Create();
|
||||
return new OrtTensorTypeAndShapeInfo();
|
||||
}
|
||||
|
||||
ORT_API(void, OrtReleaseTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtSetTensorElementType, _In_ OrtTensorTypeAndShapeInfo* this_ptr, enum ONNXTensorElementDataType type) {
|
||||
|
|
@ -126,13 +119,13 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape, const on
|
|||
OrtTensorTypeAndShapeInfo* ret = OrtCreateTensorTypeAndShapeInfo();
|
||||
auto status = OrtSetTensorElementType(ret, type);
|
||||
if (status != nullptr) {
|
||||
OrtReleaseObject(ret);
|
||||
OrtReleaseTensorTypeAndShapeInfo(ret);
|
||||
return status;
|
||||
}
|
||||
if (shape != nullptr) {
|
||||
status = OrtSetDims(ret, shape->GetDims().data(), shape->GetDims().size());
|
||||
if (status != nullptr) {
|
||||
OrtReleaseObject(ret);
|
||||
OrtReleaseTensorTypeAndShapeInfo(ret);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
|
@ -160,7 +153,7 @@ ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value) {
|
|||
return ONNX_TYPE_UNKNOWN;
|
||||
}
|
||||
ONNXType ret = out->type;
|
||||
OrtReleaseObject(out);
|
||||
OrtReleaseTypeInfo(out);
|
||||
return ret;
|
||||
} catch (std::exception&) {
|
||||
return ONNX_TYPE_UNKNOWN;
|
||||
|
|
@ -170,7 +163,7 @@ ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value) {
|
|||
/**
|
||||
* Get the type information of an OrtValue
|
||||
* \param value
|
||||
* \return The returned value should be freed by OrtReleaseObject after use
|
||||
* \return The returned value should be freed by OrtReleaseTypeInfo after use
|
||||
*/
|
||||
ORT_API_STATUS_IMPL(OrtGetTypeInfo, _In_ const OrtValue* value, struct OrtTypeInfo** out) {
|
||||
auto v = reinterpret_cast<const ::onnxruntime::MLValue*>(value);
|
||||
|
|
|
|||
|
|
@ -4,52 +4,33 @@
|
|||
#include "core/providers/cpu/cpu_provider_factory.h"
|
||||
#include <atomic>
|
||||
#include "cpu_execution_provider.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
struct CpuProviderFactory {
|
||||
const OrtProviderFactoryInterface* const cls;
|
||||
std::atomic_int ref_count;
|
||||
bool create_arena;
|
||||
CpuProviderFactory();
|
||||
struct CpuProviderFactory : IExecutionProviderFactory {
|
||||
CpuProviderFactory(bool create_arena) : create_arena_(create_arena) {}
|
||||
~CpuProviderFactory() override {}
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
||||
private:
|
||||
bool create_arena_;
|
||||
};
|
||||
|
||||
OrtStatus* ORT_API_CALL CreateCpu(void* this_, OrtProvider** out) {
|
||||
std::unique_ptr<IExecutionProvider> CpuProviderFactory::CreateProvider() {
|
||||
CPUExecutionProviderInfo info;
|
||||
CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_;
|
||||
info.create_arena = this_ptr->create_arena;
|
||||
CPUExecutionProvider* ret = new CPUExecutionProvider(info);
|
||||
*out = (OrtProvider*)ret;
|
||||
return nullptr;
|
||||
info.create_arena = create_arena_;
|
||||
return std::make_unique<CPUExecutionProvider>(info);
|
||||
}
|
||||
|
||||
uint32_t ORT_API_CALL ReleaseCpu(void* this_) {
|
||||
CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_;
|
||||
if (--this_ptr->ref_count == 0)
|
||||
delete this_ptr;
|
||||
return 0;
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena) {
|
||||
return std::make_shared<onnxruntime::CpuProviderFactory>(use_arena != 0);
|
||||
}
|
||||
|
||||
uint32_t ORT_API_CALL AddRefCpu(void* this_) {
|
||||
CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_;
|
||||
++this_ptr->ref_count;
|
||||
return 0;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
constexpr OrtProviderFactoryInterface cpu_cls = {
|
||||
{AddRefCpu,
|
||||
ReleaseCpu},
|
||||
CreateCpu,
|
||||
};
|
||||
|
||||
CpuProviderFactory::CpuProviderFactory() : cls(&cpu_cls), ref_count(1), create_arena(true) {}
|
||||
} // namespace
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtCreateCpuExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) {
|
||||
CpuProviderFactory* ret = new CpuProviderFactory();
|
||||
ret->create_arena = (use_arena != 0);
|
||||
*out = (OrtProviderFactoryInterface**)ret;
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) {
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CPU(use_arena));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
OrtAddRefToObject
|
||||
OrtAllocatorAlloc
|
||||
OrtAllocatorFree
|
||||
OrtAllocatorGetInfo
|
||||
|
|
@ -12,7 +11,6 @@ OrtCloneSessionOptions
|
|||
OrtCompareAllocatorInfo
|
||||
OrtCreateAllocatorInfo
|
||||
OrtCreateCpuAllocatorInfo
|
||||
OrtCreateCpuExecutionProviderFactory
|
||||
OrtCreateDefaultAllocator
|
||||
OrtCreateRunOptions
|
||||
OrtCreateSession
|
||||
|
|
@ -47,9 +45,12 @@ OrtIsTensor
|
|||
OrtReleaseAllocator
|
||||
OrtReleaseAllocatorInfo
|
||||
OrtReleaseEnv
|
||||
OrtReleaseObject
|
||||
OrtReleaseRunOptions
|
||||
OrtReleaseSession
|
||||
OrtReleaseSessionOptions
|
||||
OrtReleaseStatus
|
||||
OrtReleaseTensorTypeAndShapeInfo
|
||||
OrtReleaseTypeInfo
|
||||
OrtReleaseValue
|
||||
OrtRun
|
||||
OrtRunOptionsGetRunLogVerbosityLevel
|
||||
|
|
@ -63,7 +64,7 @@ OrtSessionGetInputTypeInfo
|
|||
OrtSessionGetOutputCount
|
||||
OrtSessionGetOutputName
|
||||
OrtSessionGetOutputTypeInfo
|
||||
OrtSessionOptionsAppendExecutionProvider
|
||||
OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
OrtSetDims
|
||||
OrtSetSessionLogId
|
||||
OrtSetSessionLogVerbosityLevel
|
||||
|
|
|
|||
|
|
@ -4,51 +4,35 @@
|
|||
#include "core/providers/cuda/cuda_provider_factory.h"
|
||||
#include <atomic>
|
||||
#include "cuda_execution_provider.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
namespace {
|
||||
struct CUDAProviderFactory {
|
||||
const OrtProviderFactoryInterface* const cls;
|
||||
std::atomic_int ref_count;
|
||||
int device_id;
|
||||
CUDAProviderFactory();
|
||||
namespace onnxruntime {
|
||||
|
||||
struct CUDAProviderFactory : IExecutionProviderFactory {
|
||||
CUDAProviderFactory(int device_id) : device_id_(device_id) {}
|
||||
~CUDAProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
||||
private:
|
||||
int device_id_;
|
||||
};
|
||||
|
||||
OrtStatus* ORT_API_CALL CreateCuda(void* this_, OrtProvider** out) {
|
||||
std::unique_ptr<IExecutionProvider> CUDAProviderFactory::CreateProvider() {
|
||||
CUDAExecutionProviderInfo info;
|
||||
CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_;
|
||||
info.device_id = this_ptr->device_id;
|
||||
CUDAExecutionProvider* ret = new CUDAExecutionProvider(info);
|
||||
*out = (OrtProvider*)ret;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
uint32_t ORT_API_CALL ReleaseCuda(void* this_) {
|
||||
CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_;
|
||||
if (--this_ptr->ref_count == 0)
|
||||
delete this_ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t ORT_API_CALL AddRefCuda(void* this_) {
|
||||
CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_;
|
||||
++this_ptr->ref_count;
|
||||
return 0;
|
||||
}
|
||||
|
||||
constexpr OrtProviderFactoryInterface cuda_cls = {
|
||||
AddRefCuda,
|
||||
ReleaseCuda,
|
||||
CreateCuda,
|
||||
};
|
||||
|
||||
CUDAProviderFactory::CUDAProviderFactory() : cls(&cuda_cls), ref_count(1), device_id(0) {}
|
||||
} // namespace
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtCreateCUDAExecutionProviderFactory, int device_id, _Out_ OrtProviderFactoryInterface*** out) {
|
||||
CUDAProviderFactory* ret = new CUDAProviderFactory();
|
||||
ret->device_id = device_id;
|
||||
*out = (OrtProviderFactoryInterface**)ret;
|
||||
info.device_id = device_id_;
|
||||
return std::make_unique<CUDAExecutionProvider>(info);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(int device_id) {
|
||||
return std::make_shared<onnxruntime::CUDAProviderFactory>(device_id);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id) {
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(device_id));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
OrtCreateCUDAExecutionProviderFactory
|
||||
OrtSessionOptionsAppendExecutionProvider_CUDA
|
||||
|
|
@ -4,51 +4,34 @@
|
|||
#include "core/providers/mkldnn/mkldnn_provider_factory.h"
|
||||
#include <atomic>
|
||||
#include "mkldnn_execution_provider.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
namespace {
|
||||
struct MkldnnProviderFactory {
|
||||
const OrtProviderFactoryInterface* const cls;
|
||||
std::atomic_int ref_count;
|
||||
bool create_arena;
|
||||
MkldnnProviderFactory();
|
||||
namespace onnxruntime {
|
||||
struct MkldnnProviderFactory : IExecutionProviderFactory {
|
||||
MkldnnProviderFactory(bool create_arena) : create_arena_(create_arena) {}
|
||||
~MkldnnProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
||||
private:
|
||||
bool create_arena_;
|
||||
};
|
||||
|
||||
OrtStatus* ORT_API_CALL CreateMkldnn(void* this_, OrtProvider** out) {
|
||||
std::unique_ptr<IExecutionProvider> MkldnnProviderFactory::CreateProvider() {
|
||||
MKLDNNExecutionProviderInfo info;
|
||||
MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_;
|
||||
info.create_arena = this_ptr->create_arena;
|
||||
MKLDNNExecutionProvider* ret = new MKLDNNExecutionProvider(info);
|
||||
*out = (OrtProvider*)ret;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
uint32_t ORT_API_CALL ReleaseMkldnn(void* this_) {
|
||||
MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_;
|
||||
if (--this_ptr->ref_count == 0)
|
||||
delete this_ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t ORT_API_CALL AddRefMkldnn(void* this_) {
|
||||
MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_;
|
||||
++this_ptr->ref_count;
|
||||
return 0;
|
||||
}
|
||||
|
||||
constexpr OrtProviderFactoryInterface mkl_cls = {
|
||||
{AddRefMkldnn,
|
||||
ReleaseMkldnn},
|
||||
CreateMkldnn,
|
||||
};
|
||||
|
||||
MkldnnProviderFactory::MkldnnProviderFactory() : cls(&mkl_cls), ref_count(1), create_arena(true) {}
|
||||
} // namespace
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtCreateMkldnnExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) {
|
||||
MkldnnProviderFactory* ret = new MkldnnProviderFactory();
|
||||
ret->create_arena = (use_arena != 0);
|
||||
*out = (OrtProviderFactoryInterface**)ret;
|
||||
info.create_arena = create_arena_;
|
||||
return std::make_unique<MKLDNNExecutionProvider>(info);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Mkldnn(int device_id) {
|
||||
return std::make_shared<onnxruntime::MkldnnProviderFactory>(device_id);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena) {
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Mkldnn(use_arena));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
OrtCreateMkldnnExecutionProviderFactory
|
||||
OrtSessionOptionsAppendExecutionProvider_Mkldnn
|
||||
|
|
@ -8,10 +8,6 @@
|
|||
#include "abi_session_options_impl.h"
|
||||
|
||||
OrtSessionOptions::~OrtSessionOptions() {
|
||||
assert(ref_count == 0);
|
||||
for (OrtProviderFactoryInterface** p : provider_factories) {
|
||||
OrtReleaseObject(p);
|
||||
}
|
||||
}
|
||||
|
||||
OrtSessionOptions& OrtSessionOptions::operator=(const OrtSessionOptions&) {
|
||||
|
|
@ -19,15 +15,17 @@ OrtSessionOptions& OrtSessionOptions::operator=(const OrtSessionOptions&) {
|
|||
}
|
||||
OrtSessionOptions::OrtSessionOptions(const OrtSessionOptions& other)
|
||||
: value(other.value), custom_op_paths(other.custom_op_paths), provider_factories(other.provider_factories) {
|
||||
for (OrtProviderFactoryInterface** p : other.provider_factories) {
|
||||
OrtAddRefToObject(p);
|
||||
}
|
||||
}
|
||||
|
||||
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions) {
|
||||
std::unique_ptr<OrtSessionOptions> options = std::make_unique<OrtSessionOptions>();
|
||||
return options.release();
|
||||
}
|
||||
|
||||
ORT_API(void, OrtReleaseSessionOptions, OrtSessionOptions* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) {
|
||||
try {
|
||||
return new OrtSessionOptions(*input);
|
||||
|
|
@ -36,11 +34,6 @@ ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) {
|
|||
}
|
||||
}
|
||||
|
||||
ORT_API(void, OrtSessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ OrtProviderFactoryInterface** f) {
|
||||
OrtAddRefToObject(f);
|
||||
options->provider_factories.push_back(f);
|
||||
}
|
||||
|
||||
ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_sequential_execution = true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include "core/framework/onnx_object_cxx.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/providers/providers.h"
|
||||
|
||||
struct OrtSessionOptions : public onnxruntime::ObjectBase<OrtSessionOptions> {
|
||||
struct OrtSessionOptions {
|
||||
onnxruntime::SessionOptions value;
|
||||
std::vector<std::string> custom_op_paths;
|
||||
std::vector<OrtProviderFactoryInterface**> provider_factories;
|
||||
std::vector<std::shared_ptr<onnxruntime::IExecutionProviderFactory>> provider_factories;
|
||||
OrtSessionOptions() = default;
|
||||
~OrtSessionOptions();
|
||||
OrtSessionOptions(const OrtSessionOptions& other);
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@
|
|||
#include "core/framework/environment.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/onnxruntime_typeinfo.h"
|
||||
#include "core/framework/onnx_object_cxx.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
#include "abi_session_options_impl.h"
|
||||
|
|
@ -49,7 +48,6 @@ struct OrtEnv {
|
|||
public:
|
||||
Environment* value;
|
||||
LoggingManager* loggingManager;
|
||||
friend class onnxruntime::ObjectBase<OrtEnv>;
|
||||
|
||||
OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) {
|
||||
}
|
||||
|
|
@ -367,13 +365,10 @@ static OrtStatus* CreateSessionImpl(_In_ OrtEnv* env, _In_ T model_path,
|
|||
return ToOrtStatus(status);
|
||||
}
|
||||
if (options != nullptr)
|
||||
for (OrtProviderFactoryInterface** p : options->provider_factories) {
|
||||
OrtProvider* provider;
|
||||
OrtStatus* error_code = (*p)->CreateProvider(p, &provider);
|
||||
if (error_code)
|
||||
return error_code;
|
||||
sess->RegisterExecutionProvider(std::unique_ptr<onnxruntime::IExecutionProvider>(
|
||||
reinterpret_cast<onnxruntime::IExecutionProvider*>(provider)));
|
||||
for (auto& factory : options->provider_factories) {
|
||||
auto provider = factory->CreateProvider();
|
||||
if (provider)
|
||||
sess->RegisterExecutionProvider(std::move(provider));
|
||||
}
|
||||
status = sess->Load(model_path);
|
||||
if (!status.IsOK())
|
||||
|
|
@ -638,5 +633,6 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t
|
|||
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
|
||||
DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@
|
|||
|
||||
#define BACKEND_DEVICE BACKEND_PROC BACKEND_MKLDNN BACKEND_MKLML BACKEND_OPENBLAS
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/providers/providers.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/providers/cpu/cpu_provider_factory.h"
|
||||
|
||||
|
|
@ -54,6 +55,15 @@
|
|||
#ifdef USE_NUPHAR
|
||||
#include "core/providers/nuphar/nuphar_provider_factory.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Mkldnn(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(int device_id, const char*);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_BrainSlice(int id, bool f, const char*, const char*, const char*);
|
||||
} // namespace onnxruntime
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable : 4267 4996 4503 4003)
|
||||
#endif // _MSC_VER
|
||||
|
|
@ -172,45 +182,32 @@ class SessionObjectInitializer {
|
|||
}
|
||||
};
|
||||
|
||||
inline void RegisterExecutionProvider(InferenceSession* sess, OrtProviderFactoryInterface** f) {
|
||||
OrtProvider* p;
|
||||
(*f)->CreateProvider(f, &p);
|
||||
std::unique_ptr<onnxruntime::IExecutionProvider> q((onnxruntime::IExecutionProvider*)p);
|
||||
auto status = sess->RegisterExecutionProvider(std::move(q));
|
||||
inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) {
|
||||
auto p = f.CreateProvider();
|
||||
auto status = sess->RegisterExecutionProvider(std::move(p));
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ErrorMessage().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
#define FACTORY_PTR_HOLDER \
|
||||
std::unique_ptr<OrtProviderFactoryInterface*, decltype(&OrtReleaseObject)> ptr_holder_(f, OrtReleaseObject);
|
||||
|
||||
void InitializeSession(InferenceSession* sess) {
|
||||
onnxruntime::common::Status status;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
{
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
|
||||
RegisterExecutionProvider(sess, f);
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_MKLDNN
|
||||
{
|
||||
const bool enable_cpu_mem_arena = true;
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena ? 1 : 0, &f));
|
||||
RegisterExecutionProvider(sess, f);
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Mkldnn(enable_cpu_mem_arena ? 1 : 0));
|
||||
}
|
||||
#endif
|
||||
#if 0 //USE_NUPHAR
|
||||
{
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
|
||||
RegisterExecutionProvider(sess, f);
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nuphar(0, ""));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -203,10 +203,7 @@ int real_main(int argc, char* argv[]) {
|
|||
sf.DisableSequentialExecution();
|
||||
if (enable_cuda) {
|
||||
#ifdef USE_CUDA
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
|
||||
sf.AppendExecutionProvider(f);
|
||||
OrtReleaseObject(f);
|
||||
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
|
||||
#else
|
||||
fprintf(stderr, "CUDA is not supported in this build");
|
||||
return -1;
|
||||
|
|
@ -214,10 +211,7 @@ int real_main(int argc, char* argv[]) {
|
|||
}
|
||||
if (enable_nuphar) {
|
||||
#ifdef USE_NUPHAR
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
|
||||
sf.AppendExecutionProvider(f);
|
||||
OrtReleaseObject(f);
|
||||
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, ""));
|
||||
#else
|
||||
fprintf(stderr, "Nuphar is not supported in this build");
|
||||
return -1;
|
||||
|
|
@ -225,10 +219,7 @@ int real_main(int argc, char* argv[]) {
|
|||
}
|
||||
if (enable_mkl) {
|
||||
#ifdef USE_MKLDNN
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena ? 1 : 0, &f));
|
||||
sf.AppendExecutionProvider(f);
|
||||
OrtReleaseObject(f);
|
||||
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, enable_cpu_mem_arena ? 1 : 0));
|
||||
#else
|
||||
fprintf(stderr, "MKL-DNN is not supported in this build");
|
||||
return -1;
|
||||
|
|
|
|||
|
|
@ -12,21 +12,17 @@
|
|||
#include <filesystem>
|
||||
#endif
|
||||
#include "providers.h"
|
||||
#include "default_providers.h"
|
||||
|
||||
using namespace std::experimental::filesystem::v1;
|
||||
using onnxruntime::Status;
|
||||
|
||||
inline void RegisterExecutionProvider(onnxruntime::InferenceSession* sess, OrtProviderFactoryInterface** f) {
|
||||
OrtProvider* p;
|
||||
(*f)->CreateProvider(f, &p);
|
||||
std::unique_ptr<onnxruntime::IExecutionProvider> q((onnxruntime::IExecutionProvider*)p);
|
||||
auto status = sess->RegisterExecutionProvider(std::move(q));
|
||||
inline void RegisterExecutionProvider(onnxruntime::InferenceSession* sess, std::unique_ptr<onnxruntime::IExecutionProvider>&& f) {
|
||||
auto status = sess->RegisterExecutionProvider(std::move(f));
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error(status.ErrorMessage().c_str());
|
||||
}
|
||||
}
|
||||
#define FACTORY_PTR_HOLDER \
|
||||
std::unique_ptr<OrtProviderFactoryInterface*, decltype(&OrtReleaseObject)> ptr_holder_(f, OrtReleaseObject);
|
||||
|
||||
Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>& sess, const path& model_url, const std::string& logid) const {
|
||||
::onnxruntime::SessionOptions so;
|
||||
|
|
@ -41,37 +37,25 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>&
|
|||
for (const std::string& provider : providers_) {
|
||||
if (provider == onnxruntime::kCudaExecutionProvider) {
|
||||
#ifdef USE_CUDA
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess.get(), f);
|
||||
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultCudaExecutionProvider());
|
||||
#else
|
||||
ORT_THROW("CUDA is not supported in this build");
|
||||
#endif
|
||||
} else if (provider == onnxruntime::kMklDnnExecutionProvider) {
|
||||
#ifdef USE_MKLDNN
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena_ ? 1 : 0, &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess.get(), f);
|
||||
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultMkldnnExecutionProvider(enable_cpu_mem_arena_ ? 1 : 0));
|
||||
#else
|
||||
ORT_THROW("CUDA is not supported in this build");
|
||||
#endif
|
||||
} else if (provider == onnxruntime::kNupharExecutionProvider) {
|
||||
#ifdef USE_NUPHAR
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
|
||||
RegisterExecutionProvider(sess.get(), f);
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultNupharExecutionProvider());
|
||||
#else
|
||||
ORT_THROW("CUDA is not supported in this build");
|
||||
#endif
|
||||
} else if (provider == onnxruntime::kBrainSliceExecutionProvider) {
|
||||
#if USE_BRAINSLICE
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateBrainSliceExecutionProviderFactory(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f));
|
||||
RegisterExecutionProvider(sess.get(), f);
|
||||
FACTORY_PTR_HOLDER;
|
||||
RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultBrainsliceExecutionProvider());
|
||||
#else
|
||||
ORT_THROW("This executable was not built with BrainSlice");
|
||||
#endif
|
||||
|
|
@ -85,7 +69,7 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>&
|
|||
ORT_THROW("TensorRT is not supported in this build");
|
||||
#endif
|
||||
}
|
||||
//TODO: add more
|
||||
// TODO: add more
|
||||
}
|
||||
|
||||
status = sess->Load(model_url.string());
|
||||
|
|
|
|||
|
|
@ -182,10 +182,7 @@ void verify_input_output_count(OrtSession* session) {
|
|||
|
||||
#ifdef USE_CUDA
|
||||
void enable_cuda(OrtSessionOptions* session_option) {
|
||||
OrtProviderFactoryInterface** factory;
|
||||
ORT_ABORT_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &factory));
|
||||
OrtSessionOptionsAppendExecutionProvider(session_option, factory);
|
||||
OrtReleaseObject(factory);
|
||||
ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -207,9 +204,9 @@ int main(int argc, char* argv[]) {
|
|||
ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_option, &session));
|
||||
verify_input_output_count(session);
|
||||
int ret = run_inference(session, input_file, output_file);
|
||||
OrtReleaseObject(session_option);
|
||||
OrtReleaseSessionOptions(session_option);
|
||||
OrtReleaseSession(session);
|
||||
OrtReleaseObject(env);
|
||||
OrtReleaseEnv(env);
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "fail\n");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,30 +64,21 @@ void TestInference(OrtEnv* env, T model_uri,
|
|||
|
||||
if (provider_type == 1) {
|
||||
#ifdef USE_CUDA
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
|
||||
sf.AppendExecutionProvider(f);
|
||||
OrtReleaseObject(f);
|
||||
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0));
|
||||
std::cout << "Running simple inference with cuda provider" << std::endl;
|
||||
#else
|
||||
return;
|
||||
#endif
|
||||
} else if (provider_type == 2) {
|
||||
#ifdef USE_MKLDNN
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(1, &f));
|
||||
sf.AppendExecutionProvider(f);
|
||||
OrtReleaseObject(f);
|
||||
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, 1));
|
||||
std::cout << "Running simple inference with mkldnn provider" << std::endl;
|
||||
#else
|
||||
return;
|
||||
#endif
|
||||
} else if (provider_type == 3) {
|
||||
#ifdef USE_NUPHAR
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
|
||||
sf.AppendExecutionProvider(f);
|
||||
OrtReleaseObject(f);
|
||||
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, ""));
|
||||
std::cout << "Running simple inference with nuphar provider" << std::endl;
|
||||
#else
|
||||
return;
|
||||
|
|
@ -196,7 +187,7 @@ TEST_F(CApiTest, create_tensor_with_data) {
|
|||
const struct OrtTensorTypeAndShapeInfo* tensor_info = OrtCastTypeInfoToTensorInfo(type_info);
|
||||
ASSERT_NE(tensor_info, nullptr);
|
||||
ASSERT_EQ(1, OrtGetNumOfDimensions(tensor_info));
|
||||
OrtReleaseObject(type_info);
|
||||
OrtReleaseTypeInfo(type_info);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
|
|
|||
|
|
@ -4,28 +4,25 @@
|
|||
#include "default_providers.h"
|
||||
#include "providers.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#define FACTORY_PTR_HOLDER \
|
||||
std::unique_ptr<OrtProviderFactoryInterface*> ptr_holder_(f);
|
||||
#include "core/providers/providers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CPU(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CUDA(int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Mkldnn(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(int device_id, const char*);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_BrainSlice(int id, bool f, const char*, const char*, const char*);
|
||||
|
||||
namespace test {
|
||||
|
||||
std::unique_ptr<IExecutionProvider> DefaultCpuExecutionProvider(bool enable_arena) {
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateCpuExecutionProviderFactory(enable_arena ? 1 : 0, &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
OrtProvider* out;
|
||||
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
|
||||
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
|
||||
return CreateExecutionProviderFactory_CPU(enable_arena)->CreateProvider();
|
||||
}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider() {
|
||||
#ifdef USE_CUDA
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
OrtProvider* out;
|
||||
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
|
||||
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
|
||||
return CreateExecutionProviderFactory_CUDA(0)->CreateProvider();
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
|
|
@ -33,12 +30,7 @@ std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider() {
|
|||
|
||||
std::unique_ptr<IExecutionProvider> DefaultMkldnnExecutionProvider(bool enable_arena) {
|
||||
#ifdef USE_MKLDNN
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_arena ? 1 : 0, &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
OrtProvider* out;
|
||||
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
|
||||
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
|
||||
return CreateExecutionProviderFactory_Mkldnn(enable_arena ? 1 : 0)->CreateProvider();
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(enable_arena);
|
||||
return nullptr;
|
||||
|
|
@ -47,12 +39,7 @@ std::unique_ptr<IExecutionProvider> DefaultMkldnnExecutionProvider(bool enable_a
|
|||
|
||||
std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider() {
|
||||
#ifdef USE_NUPHAR
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
OrtProvider* out;
|
||||
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
|
||||
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
|
||||
return CreateExecutionProviderFactory_Nuphar(0, "")->CreateProvider();
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
|
|
@ -60,12 +47,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider() {
|
|||
|
||||
std::unique_ptr<IExecutionProvider> DefaultBrainSliceExecutionProvider() {
|
||||
#ifdef USE_BRAINSLICE
|
||||
OrtProviderFactoryInterface** f;
|
||||
ORT_THROW_ON_ERROR(OrtCreateBrainSliceExecutionProviderFactory(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f));
|
||||
FACTORY_PTR_HOLDER;
|
||||
OrtProvider* out;
|
||||
ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out));
|
||||
return std::unique_ptr<IExecutionProvider>((IExecutionProvider*)out);
|
||||
return CreateExecutionProviderFactory_BrainSlice(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f));
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue