From d875ab2acdafe415a9eccbef68954fb608a76e9f Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Fri, 25 Jan 2019 19:41:10 -0800 Subject: [PATCH] C API - Remove reference counting (#344) --- .../ExecutionProviderFactory.cs | 106 ------------------ .../InferenceSession.cs | 8 +- .../NamedOnnxValue.cs | 2 +- .../NativeMemoryAllocator.cs | 2 +- .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 32 +++--- .../NativeOnnxObjectHandle.cs | 28 ----- .../NativeOnnxTensorMemory.cs | 2 +- .../SessionOptions.cs | 64 ++--------- .../InferenceTest.cs | 6 +- .../core/framework/onnx_object_cxx.h | 47 -------- .../onnxruntime/core/framework/run_options.h | 3 +- .../core/providers/cpu/cpu_provider_factory.h | 3 +- .../providers/cuda/cuda_provider_factory.h | 4 +- .../mkldnn/mkldnn_provider_factory.h | 3 +- .../onnxruntime/core/providers/providers.h | 11 ++ .../core/session/onnxruntime_c_api.h | 75 ++++--------- .../core/session/onnxruntime_cxx_api.h | 57 +++++----- onnxruntime/core/framework/onnx_object.cc | 21 ---- .../core/framework/onnxruntime_typeinfo.cc | 13 ++- .../core/framework/onnxruntime_typeinfo.h | 16 +-- .../core/framework/tensor_type_and_shape.cc | 35 +++--- .../providers/cpu/cpu_provider_factory.cc | 53 +++------ onnxruntime/core/providers/cpu/symbols.txt | 9 +- .../providers/cuda/cuda_provider_factory.cc | 64 ++++------- onnxruntime/core/providers/cuda/symbols.txt | 2 +- .../mkldnn/mkldnn_provider_factory.cc | 63 ++++------- onnxruntime/core/providers/mkldnn/symbols.txt | 2 +- .../core/session/abi_session_options.cc | 17 +-- .../core/session/abi_session_options_impl.h | 6 +- onnxruntime/core/session/onnxruntime_c_api.cc | 14 +-- .../python/onnxruntime_pybind_state.cc | 37 +++--- onnxruntime/test/onnx/main.cc | 15 +-- onnxruntime/test/perftest/testenv.cc | 32 ++---- .../shared_lib/fns_candy_style_transfer.c | 9 +- onnxruntime/test/shared_lib/test_inference.cc | 17 +-- onnxruntime/test/util/default_providers.cc | 46 +++----- 
36 files changed, 263 insertions(+), 661 deletions(-) delete mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/ExecutionProviderFactory.cs delete mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxObjectHandle.cs delete mode 100644 include/onnxruntime/core/framework/onnx_object_cxx.h create mode 100644 include/onnxruntime/core/providers/providers.h delete mode 100644 onnxruntime/core/framework/onnx_object.cc diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ExecutionProviderFactory.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ExecutionProviderFactory.cs deleted file mode 100644 index 49eed94acf..0000000000 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ExecutionProviderFactory.cs +++ /dev/null @@ -1,106 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Runtime.InteropServices; - -namespace Microsoft.ML.OnnxRuntime -{ - - internal class CpuExecutionProviderFactory: NativeOnnxObjectHandle - { - protected static readonly Lazy _default = new Lazy(() => new CpuExecutionProviderFactory()); - - public CpuExecutionProviderFactory(bool useArena=true) - :base(IntPtr.Zero) - { - int useArenaInt = useArena ? 1 : 0; - try - { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuExecutionProviderFactory(useArenaInt, out handle)); - } - catch(OnnxRuntimeException e) - { - if (IsInvalid) - { - ReleaseHandle(); - handle = IntPtr.Zero; - } - throw e; - } - } - - public static CpuExecutionProviderFactory Default - { - get - { - return _default.Value; - } - } - } - - internal class MklDnnExecutionProviderFactory : NativeOnnxObjectHandle - { - protected static readonly Lazy _default = new Lazy(() => new MklDnnExecutionProviderFactory()); - - public MklDnnExecutionProviderFactory(bool useArena = true) - :base(IntPtr.Zero) - { - int useArenaInt = useArena ? 
1 : 0; - try - { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMkldnnExecutionProviderFactory(useArenaInt, out handle)); - } - catch (OnnxRuntimeException e) - { - if (IsInvalid) - { - ReleaseHandle(); - handle = IntPtr.Zero; - } - throw e; - } - } - - public static MklDnnExecutionProviderFactory Default - { - get - { - return _default.Value; - } - } - } - - internal class CudaExecutionProviderFactory : NativeOnnxObjectHandle - { - protected static readonly Lazy _default = new Lazy(() => new CudaExecutionProviderFactory()); - - public CudaExecutionProviderFactory(int deviceId = 0) - : base(IntPtr.Zero) - { - try - { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCUDAExecutionProviderFactory(deviceId, out handle)); - } - catch (OnnxRuntimeException e) - { - if (IsInvalid) - { - ReleaseHandle(); - handle = IntPtr.Zero; - } - throw e; - } - } - - public static CudaExecutionProviderFactory Default - { - get - { - return _default.Value; - } - } - } - - - -} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 16bcbc83f9..5836837fda 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -45,9 +45,9 @@ namespace Microsoft.ML.OnnxRuntime try { if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.NativeHandle, out _nativeHandle)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options._nativePtr, out _nativeHandle)); else - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.NativeHandle, out _nativeHandle)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, 
System.Text.Encoding.UTF8.GetBytes(modelPath), options._nativePtr, out _nativeHandle)); // Initialize input/output metadata _inputMetadata = new Dictionary(); @@ -275,7 +275,7 @@ namespace Microsoft.ML.OnnxRuntime { if (typeInfo != IntPtr.Zero) { - NativeMethods.OrtReleaseObject(typeInfo); + NativeMethods.OrtReleaseTypeInfo(typeInfo); } } } @@ -292,7 +292,7 @@ namespace Microsoft.ML.OnnxRuntime { if (typeInfo != IntPtr.Zero) { - NativeMethods.OrtReleaseObject(typeInfo); + NativeMethods.OrtReleaseTypeInfo(typeInfo); } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs index 19eb93e662..cf144eb256 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs @@ -198,7 +198,7 @@ namespace Microsoft.ML.OnnxRuntime { if (typeAndShape != IntPtr.Zero) { - NativeMethods.OrtReleaseObject(typeAndShape); + NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs index 0ff438e857..a9b4e60f5a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMemoryAllocator.cs @@ -141,7 +141,7 @@ namespace Microsoft.ML.OnnxRuntime protected static void Delete(IntPtr allocator) { - NativeMethods.OrtReleaseObject(allocator); + NativeMethods.OrtReleaseAllocator(allocator); } protected override bool ReleaseHandle() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 4cf0ee85f0..43833783fb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -90,20 +90,22 @@ namespace Microsoft.ML.OnnxRuntime IntPtr /*(OrtAllocator*)*/ allocator, out IntPtr /*(char**)*/name); - // release the typeinfo using 
OrtReleaseObject + // release the typeinfo using OrtReleaseTypeInfo [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/OrtSessionGetInputTypeInfo( IntPtr /*(const OrtSession*)*/ session, ulong index, //TODO: port for size_t out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo); - // release the typeinfo using OrtReleaseObject + // release the typeinfo using OrtReleaseTypeInfo [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/OrtSessionGetOutputTypeInfo( IntPtr /*(const OrtSession*)*/ session, ulong index, //TODO: port for size_t out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); + [DllImport(nativeLib, CharSet = charSet)] + public static extern void OrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session); [DllImport(nativeLib, CharSet = charSet)] public static extern void OrtReleaseSession(IntPtr /*(OrtSession*)*/session); @@ -112,11 +114,12 @@ namespace Microsoft.ML.OnnxRuntime #region SessionOptions API - //Release using OrtReleaseObject [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*OrtSessionOptions* */ OrtCreateSessionOptions(); - + [DllImport(nativeLib, CharSet = charSet)] + public static extern void OrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtSessionOptions*)*/OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions); @@ -153,22 +156,20 @@ namespace Microsoft.ML.OnnxRuntime [DllImport(nativeLib, CharSet = charSet)] public static extern int OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize); + ///** // * The order of invocation indicates the preference order as well. In other words call this method // * on your most preferred execution provider first followed by the less preferred ones. // * Calling this API is optional in which case onnxruntime will use its internal CPU execution provider. 
// */ [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtSessionOptionsAppendExecutionProvider(IntPtr /*(OrtSessionOptions*)*/ options, IntPtr /* (OrtProviderFactoryPtr*)*/ factory); + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CPU(IntPtr /*(OrtSessionOptions*) */ options, int use_arena); [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCpuExecutionProviderFactory(int use_arena, out IntPtr /*(OrtProviderFactoryPtr*)*/ factory); + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Mkldnn(IntPtr /*(OrtSessionOptions*) */ options, int use_arena); [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtStatus*)*/ OrtCreateMkldnnExecutionProviderFactory(int use_arena, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory); - - [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtStatus*)*/ OrtCreateCUDAExecutionProviderFactory(int device_id, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory); + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id); //[DllImport(nativeLib, CharSet = charSet)] //public static extern IntPtr /*(OrtStatus*)*/ OrtCreateNupharExecutionProviderFactory(int device_id, string target_str, out IntPtr /*(OrtProviderFactoryPtr**)*/ factory); @@ -220,13 +221,8 @@ namespace Microsoft.ML.OnnxRuntime [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/OrtCreateDefaultAllocator(out IntPtr /*(OrtAllocator**)*/ allocator); - /// - /// Releases/Unrefs any object, including the Allocator - /// - /// - /// remaining ref count [DllImport(nativeLib, CharSet = charSet)] - public static extern uint /*remaining ref count*/ OrtReleaseObject(IntPtr /*(void*)*/ ptr); + public static extern void OrtReleaseAllocator(IntPtr /*(OrtAllocator*)*/ allocator); /// 
/// Release any object allocated by an allocator @@ -265,6 +261,10 @@ namespace Microsoft.ML.OnnxRuntime [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtGetTensorShapeAndType(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); + + [DllImport(nativeLib, CharSet = charSet)] + public static extern void OrtReleaseTensorTypeAndShapeInfo(IntPtr /*(OrtTensorTypeAndShapeInfo*)*/ value); + [DllImport(nativeLib, CharSet = charSet)] public static extern TensorElementType OrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxObjectHandle.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxObjectHandle.cs deleted file mode 100644 index b2fa4c8f68..0000000000 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxObjectHandle.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace Microsoft.ML.OnnxRuntime -{ - - internal class NativeOnnxObjectHandle : SafeHandle - { - public NativeOnnxObjectHandle(IntPtr ptr) - : base(IntPtr.Zero, true) - { - handle = ptr; - } - public override bool IsInvalid - { - get - { - return (handle == IntPtr.Zero); - } - } - - protected override bool ReleaseHandle() - { - NativeMethods.OrtReleaseObject(handle); - return true; - } - } -} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs index 6b4e22fb71..31e354ab04 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs @@ -70,7 +70,7 @@ namespace Microsoft.ML.OnnxRuntime { if (typeAndShape != IntPtr.Zero) { - NativeMethods.OrtReleaseObject(typeAndShape); + NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs 
b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index 0e8f88c6da..093e347836 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -23,7 +23,7 @@ namespace Microsoft.ML.OnnxRuntime /// public class SessionOptions:IDisposable { - protected SafeHandle _nativeOption; + public IntPtr _nativePtr; protected static readonly Lazy _default = new Lazy(MakeSessionOptionWithMklDnnProvider); private static string[] cudaDelayLoadedLibs = { "cublas64_100.dll", "cudnn64_7.dll" }; @@ -32,7 +32,7 @@ namespace Microsoft.ML.OnnxRuntime /// public SessionOptions() { - _nativeOption = new NativeOnnxObjectHandle(NativeMethods.OrtCreateSessionOptions()); + _nativePtr = NativeMethods.OrtCreateSessionOptions(); } /// @@ -46,33 +46,11 @@ namespace Microsoft.ML.OnnxRuntime } } - /// - /// Append an execution propvider. When any operator is evaluated, it is executed on the first execution provider that provides it - /// - /// - public void AppendExecutionProvider(ExecutionProvider provider) - { - switch (provider) - { - case ExecutionProvider.Cpu: - AppendExecutionProvider(CpuExecutionProviderFactory.Default); - break; - case ExecutionProvider.MklDnn: - AppendExecutionProvider(MklDnnExecutionProviderFactory.Default); - break; - case ExecutionProvider.Cuda: - AppendExecutionProvider(CudaExecutionProviderFactory.Default); - break; - default: - break; - } - } - private static SessionOptions MakeSessionOptionWithMklDnnProvider() { SessionOptions options = new SessionOptions(); - options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default); - options.AppendExecutionProvider(CpuExecutionProviderFactory.Default); + // NativeMethods.OrtSessionOptionsAppendExecutionProvider_Mkldnn(_nativePtr, 1); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1); return options; } @@ -94,38 +72,12 @@ namespace Microsoft.ML.OnnxRuntime { CheckCudaExecutionProviderDLLs(); SessionOptions options 
= new SessionOptions(); - if (deviceId == 0) //default value - options.AppendExecutionProvider(CudaExecutionProviderFactory.Default); - else - options.AppendExecutionProvider(new CudaExecutionProviderFactory(deviceId)); - options.AppendExecutionProvider(MklDnnExecutionProviderFactory.Default); - options.AppendExecutionProvider(CpuExecutionProviderFactory.Default); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options._nativePtr, deviceId); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_Mkldnn(options._nativePtr, 1); + NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options._nativePtr, 1); return options; } - internal IntPtr NativeHandle - { - get - { - return _nativeOption.DangerousGetHandle(); //Note: this is unsafe, and not ref counted, use with caution - } - } - - private void AppendExecutionProvider(NativeOnnxObjectHandle providerFactory) - { - unsafe - { - bool success = false; - providerFactory.DangerousAddRef(ref success); - if (success) - { - NativeMethods.OrtSessionOptionsAppendExecutionProvider(_nativeOption.DangerousGetHandle(), providerFactory.DangerousGetHandle()); - providerFactory.DangerousRelease(); - } - - } - } - // Declared, but called only if OS = Windows. 
[DllImport("kernel32.dll")] private static extern IntPtr LoadLibrary(string dllToLoad); @@ -172,7 +124,7 @@ namespace Microsoft.ML.OnnxRuntime { // cleanup managed resources } - _nativeOption.Dispose(); + NativeMethods.OrtReleaseSessionOptions(_nativePtr); } #endregion diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index d53392e5c5..9fdac5b3c6 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -522,9 +522,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests "OrtSessionGetOutputTypeInfo","OrtReleaseSession","OrtCreateSessionOptions","OrtCloneSessionOptions", "OrtEnableSequentialExecution","OrtDisableSequentialExecution","OrtEnableProfiling","OrtDisableProfiling", "OrtEnableMemPattern","OrtDisableMemPattern","OrtEnableCpuMemArena","OrtDisableCpuMemArena", - "OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider", - "OrtCreateCpuExecutionProviderFactory","OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo", - "OrtCreateDefaultAllocator","OrtReleaseObject","OrtAllocatorFree","OrtAllocatorGetInfo", + "OrtSetSessionLogId","OrtSetSessionLogVerbosityLevel","OrtSetSessionThreadPoolSize","OrtSessionOptionsAppendExecutionProvider_CPU", + "OrtCreateAllocatorInfo","OrtCreateCpuAllocatorInfo", + "OrtCreateDefaultAllocator","OrtAllocatorFree","OrtAllocatorGetInfo", "OrtCreateTensorWithDataAsOrtValue","OrtGetTensorMutableData", "OrtReleaseAllocatorInfo", "OrtCastTypeInfoToTensorInfo","OrtGetTensorShapeAndType","OrtGetTensorElementType","OrtGetNumOfDimensions", "OrtGetDimensions","OrtGetTensorShapeElementCount","OrtReleaseValue"}; diff --git a/include/onnxruntime/core/framework/onnx_object_cxx.h b/include/onnxruntime/core/framework/onnx_object_cxx.h deleted file mode 100644 index 9677f8bdda..0000000000 --- 
a/include/onnxruntime/core/framework/onnx_object_cxx.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/session/onnxruntime_c_api.h" -#include - -namespace onnxruntime { - -/** - * Even it's designed to be inherited, this class doesn't have a virtual destructor. - * No vtable is allowed in this class and its subclasses. - * \tparam T subclass type name - */ -template -class ObjectBase { - private: - static OrtObject static_cls; - - protected: - const OrtObject* const ORT_ATTRIBUTE_UNUSED cls_; - std::atomic_int ref_count; - ObjectBase() : cls_(&static_cls), ref_count(1) { - } - - static uint32_t ORT_API_CALL OrtReleaseImpl(void* this_) { - T* this_ptr = reinterpret_cast(this_); - if (--this_ptr->ref_count == 0) - delete this_ptr; - return 0; - } - - static uint32_t ORT_API_CALL OrtAddRefImpl(void* this_) { - T* this_ptr = reinterpret_cast(this_); - ++this_ptr->ref_count; - return 0; - } -}; - -template -OrtObject ObjectBase::static_cls = {ObjectBase::OrtAddRefImpl, ObjectBase::OrtReleaseImpl}; - -} // namespace onnxruntime - -#define ORT_CHECK_C_OBJECT_LAYOUT \ - { assert((char*)&ref_count == (char*)this + sizeof(this)); } diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index 9c6670a3b0..f10f161ebe 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -7,12 +7,11 @@ #include #include #include "core/session/onnxruntime_c_api.h" -#include "core/framework/onnx_object_cxx.h" /** * Configuration information for a single Run. 
*/ -struct OrtRunOptions : public onnxruntime::ObjectBase { +struct OrtRunOptions { unsigned run_log_verbosity_level = 0; ///< applies to a particular Run() invocation std::string run_tag; ///< to identify logs generated by a particular Run() invocation diff --git a/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h b/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h index cda1766ad3..32289eb5bb 100644 --- a/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h +++ b/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h @@ -9,9 +9,8 @@ extern "C" { /** * \param use_arena zero: false. non-zero: true. - * \param out Call OrtReleaseObject() method when you no longer need to use it. */ -ORT_API_STATUS(OrtCreateCpuExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) ORT_ALL_ARGS_NONNULL; ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out) diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h index 39f0fbc776..3fc4b7b51f 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h @@ -6,11 +6,11 @@ #ifdef __cplusplus extern "C" { #endif + /** * \param device_id cuda device id, starts from zero. - * \param out Call OrtReleaseObject() method when you no longer need to use it. 
*/ -ORT_API_STATUS(OrtCreateCUDAExecutionProviderFactory, int device_id, _Out_ OrtProviderFactoryInterface*** out); +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); #ifdef __cplusplus } diff --git a/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h b/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h index 25603f121c..03ef1158ee 100644 --- a/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h +++ b/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h @@ -9,9 +9,8 @@ extern "C" { /** * \param use_arena zero: false. non-zero: true. - * \param out Call OrtReleaseObject() method when you no longer need to use it. */ -ORT_API_STATUS(OrtCreateMkldnnExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out); +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena); #ifdef __cplusplus } diff --git a/include/onnxruntime/core/providers/providers.h b/include/onnxruntime/core/providers/providers.h new file mode 100644 index 0000000000..fc16812417 --- /dev/null +++ b/include/onnxruntime/core/providers/providers.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +namespace onnxruntime { +class IExecutionProvider; + +struct IExecutionProviderFactory { + virtual ~IExecutionProviderFactory() {} + virtual std::unique_ptr CreateProvider() = 0; +}; +} // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 07fdca7cd0..5c936af3a0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -71,7 +71,7 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string - ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, // + ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t @@ -139,26 +139,10 @@ ORT_RUNTIME_CLASS(AllocatorInfo); ORT_RUNTIME_CLASS(Session); ORT_RUNTIME_CLASS(Value); ORT_RUNTIME_CLASS(ValueList); - -struct OrtTypeInfo; -typedef struct OrtTypeInfo OrtTypeInfo; -struct OrtTensorTypeAndShapeInfo; -typedef struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo; -struct OrtRunOptions; -typedef struct OrtRunOptions OrtRunOptions; -struct OrtSessionOptions; -typedef struct OrtSessionOptions OrtSessionOptions; - -/** - * Every type inherented from OrtObject should be deleted by OrtReleaseObject(...). - */ -typedef struct OrtObject { - // Returns the new reference count. - uint32_t(ORT_API_CALL* AddRef)(void* this_); - // Returns the new reference count. 
- uint32_t(ORT_API_CALL* Release)(void* this_); - -} OrtObject; +ORT_RUNTIME_CLASS(RunOptions); +ORT_RUNTIME_CLASS(TypeInfo); +ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); +ORT_RUNTIME_CLASS(SessionOptions); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. @@ -168,12 +152,6 @@ typedef struct OrtAllocator { const struct OrtAllocatorInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); } OrtAllocator; -// Inherented from OrtObject -typedef struct OrtProviderFactoryInterface { - OrtObject parent; - OrtStatus*(ORT_API_CALL* CreateProvider)(void* this_, OrtProvider** out); -} OrtProviderFactoryInterface; - typedef void(ORT_API_CALL* OrtLoggingFunction)( void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message); @@ -187,7 +165,7 @@ ORT_ALL_ARGS_NONNULL; /** * OrtEnv is process-wise. For each process, only one OrtEnv can be created. Don't do it multiple times - * \param out Should be freed by `OrtReleaseObject` after use + * \param out Should be freed by `OrtReleaseEnv` after use */ ORT_API_STATUS(OrtInitializeWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, @@ -209,7 +187,7 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output); /** - * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseObject after use + * \return A pointer of the newly created object. 
The pointer should be freed by OrtReleaseSessionOptions after use */ ORT_API(OrtSessionOptions*, OrtCreateSessionOptions); @@ -245,11 +223,15 @@ ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, u ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size); /** - * The order of invocation indicates the preference order as well. In other words call this method + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session: + * OrtSessionOptionsAppendExecutionProvider_CPU + * OrtSessionOptionsAppendExecutionProvider_CUDA + * OrtSessionOptionsAppendExecutionProvider_<remaining providers...> + * The order they are called indicates the preference order as well. In other words call this method * on your most preferred execution provider first followed by the less preferred ones. - * Calling this API is optional in which case Ort will use its internal CPU execution provider. + * If none are called Ort will use its internal CPU execution provider.
*/ -ORT_API(void, OrtSessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ OrtProviderFactoryInterface** f); ORT_API(void, OrtAppendCustomOpLibPath, _In_ OrtSessionOptions* options, const char* lib_path); @@ -257,12 +239,12 @@ ORT_API_STATUS(OrtSessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_ ORT_API_STATUS(OrtSessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out); /** - * \param out should be freed by OrtReleaseObject after use + * \param out should be freed by OrtReleaseTypeInfo after use */ ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out); /** - * \param out should be freed by OrtReleaseObject after use + * \param out should be freed by OrtReleaseTypeInfo after use */ ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out); @@ -275,7 +257,7 @@ ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t inde _Inout_ OrtAllocator* allocator, _Out_ char** value); /** - * \return A pointer to the newly created object. The pointer should be freed by OrtReleaseObject after use + * \return A pointer to the newly created object. 
The pointer should be freed by OrtReleaseRunOptions after use */ ORT_API(OrtRunOptions*, OrtCreateRunOptions); @@ -345,7 +327,7 @@ ORT_API_STATUS(OrtTensorProtoToOrtValue, _Inout_ OrtAllocator* allocator, ORT_API(const OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*); /** - * The retured value should be released by calling OrtReleaseObject + * The returned value should be released by calling OrtReleaseTensorTypeAndShapeInfo */ ORT_API(OrtTensorTypeAndShapeInfo*, OrtCreateTensorTypeAndShapeInfo); @@ -374,36 +356,19 @@ ORT_API(void, OrtGetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out ORT_API(int64_t, OrtGetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info); /** - * \param out Should be freed by OrtReleaseObject after use + * \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use */ ORT_API_STATUS(OrtGetTensorShapeAndType, _In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out); /** * Get the type information of an OrtValue * \param value - * \param out The returned value should be freed by OrtReleaseObject after use + * \param out The returned value should be freed by OrtReleaseTypeInfo after use */ ORT_API_STATUS(OrtGetTypeInfo, _In_ const OrtValue* value, OrtTypeInfo** out); ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value); -/** - * This function is a wrapper to "(*(OrtObject**)ptr)->AddRef(ptr)" - * WARNING: There is NO type checking in this function. - * Before calling this function, caller should make sure current ref count > 0 - * \return the new reference count - */ -ORT_API(uint32_t, OrtAddRefToObject, _In_ void* ptr); - -/** - * - * A wrapper to "(*(OrtObject**)ptr)->Release(ptr)" - * WARNING: There is NO type checking in this function. - * \param ptr Can be NULL. If it's NULL, this function will return zero. - * \return the new reference count.
- */ -ORT_API(uint32_t, OrtReleaseObject, _Inout_opt_ void* ptr); - typedef enum OrtAllocatorType { OrtDeviceAllocator = 0, OrtArenaAllocator = 1 diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5f7bb2cb5c..ab916f5771 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -38,36 +38,48 @@ struct default_delete { OrtReleaseEnv(ptr); } }; -} // namespace std -#define DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TYPE_NAME) \ - namespace std { \ - template <> \ - struct default_delete { \ - void operator()(Ort##TYPE_NAME* ptr) { \ - (*reinterpret_cast(ptr))->Release(ptr); \ - } \ - }; \ +template <> +struct default_delete { + void operator()(OrtRunOptions* ptr) { + OrtReleaseRunOptions(ptr); } +}; -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TypeInfo); -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(TensorTypeAndShapeInfo); -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(RunOptions); -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(SessionOptions); -DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT(ProviderFactoryInterface*); +template <> +struct default_delete { + void operator()(OrtTypeInfo* ptr) { + OrtReleaseTypeInfo(ptr); + } +}; -#undef DECLARE_DEFAULT_DELETER_FOR_ONNX_OBJECT +template <> +struct default_delete { + void operator()(OrtTensorTypeAndShapeInfo* ptr) { + OrtReleaseTensorTypeAndShapeInfo(ptr); + } +}; + +template <> +struct default_delete { + void operator()(OrtSessionOptions* ptr) { + OrtReleaseSessionOptions(ptr); + } +}; +} // namespace std namespace onnxruntime { class SessionOptionsWrapper { private: - std::unique_ptr value; + std::unique_ptr value; OrtEnv* env_; - SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p, OrtReleaseObject), env_(env){}; + SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p), env_(env){}; public: + operator OrtSessionOptions*() { return value.get(); } + //TODO: for 
the input arg, should we call addref here? - SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions(), OrtReleaseObject), env_(env){}; + SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions()), env_(env){}; ORT_REDIRECT_SIMPLE_FUNCTION_CALL(EnableSequentialExecution) ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableSequentialExecution) ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableProfiling) @@ -89,15 +101,6 @@ class SessionOptionsWrapper { OrtSetSessionThreadPoolSize(value.get(), session_thread_pool_size); } - /** - * The order of invocation indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * Calling this API is optional in which case onnxruntime will use its internal CPU execution provider. - */ - void AppendExecutionProvider(_In_ OrtProviderFactoryInterface** f) { - OrtSessionOptionsAppendExecutionProvider(value.get(), f); - } - SessionOptionsWrapper clone() const { OrtSessionOptions* p = OrtCloneSessionOptions(value.get()); return SessionOptionsWrapper(env_, p); diff --git a/onnxruntime/core/framework/onnx_object.cc b/onnxruntime/core/framework/onnx_object.cc deleted file mode 100644 index 199b68dc58..0000000000 --- a/onnxruntime/core/framework/onnx_object.cc +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/onnxruntime_c_api.h" -#include - -ORT_API(uint32_t, OrtAddRefToObject, void* ptr) { - return (*static_cast(ptr))->AddRef(ptr); -} - -ORT_API(uint32_t, OrtReleaseObject, void* ptr) { - if (ptr == nullptr) return 0; - return (*static_cast(ptr))->Release(ptr); -} - -namespace { -struct ObjectImpl { - const OrtObject* const cls; - std::atomic_int ref_count; -}; -} // namespace diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 0c45349982..57ff6ca54c 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -3,8 +3,8 @@ //this file contains implementations of the C API -#include "onnxruntime_typeinfo.h" #include +#include "onnxruntime_typeinfo.h" #include "core/framework/tensor.h" #include "core/graph/onnx_protobuf.h" @@ -14,16 +14,19 @@ using onnxruntime::MLFloat16; using onnxruntime::Tensor; using onnxruntime::TensorShape; -OrtTypeInfo::OrtTypeInfo(ONNXType type1, void* data1) noexcept : type(type1), data(data1) { +OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) { } OrtTypeInfo::~OrtTypeInfo() { - assert(ref_count == 0); - OrtReleaseObject(data); + OrtReleaseTensorTypeAndShapeInfo(data); } ORT_API(const struct OrtTensorTypeAndShapeInfo*, OrtCastTypeInfoToTensorInfo, _In_ struct OrtTypeInfo* input) { - return input->type == ONNX_TYPE_TENSOR ? reinterpret_cast(input->data) : nullptr; + return input->type == ONNX_TYPE_TENSOR ? 
input->data : nullptr; +} + +ORT_API(void, OrtReleaseTypeInfo, OrtTypeInfo* ptr) { + delete ptr; } OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out); diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 9cbef90d24..7f13d3bdab 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -2,8 +2,8 @@ // Licensed under the MIT License. #pragma once -#include "core/framework/onnx_object_cxx.h" #include +#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class DataTypeImpl; @@ -18,21 +18,21 @@ class TypeProto; * the equivalent of onnx::TypeProto * This class is mainly for the C API */ -struct OrtTypeInfo : public onnxruntime::ObjectBase { +struct OrtTypeInfo { public: - friend class onnxruntime::ObjectBase; - ONNXType type = ONNX_TYPE_UNKNOWN; + + ~OrtTypeInfo(); + //owned by this - void* data = nullptr; + OrtTensorTypeAndShapeInfo* data = nullptr; OrtTypeInfo(const OrtTypeInfo& other) = delete; OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete; static OrtStatus* FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const onnxruntime::TensorShape* shape, - const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out); + const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out); static OrtStatus* FromDataTypeImpl(const onnx::TypeProto*, OrtTypeInfo** out); private: - OrtTypeInfo(ONNXType type, void* data) noexcept; - ~OrtTypeInfo(); + OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept; }; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index da355705ec..6d26e91ad0 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -15,36 +15,29 @@ using 
onnxruntime::DataTypeImpl; using onnxruntime::MLFloat16; using onnxruntime::Tensor; -struct OrtTensorTypeAndShapeInfo : public onnxruntime::ObjectBase { +struct OrtTensorTypeAndShapeInfo { public: - friend class onnxruntime::ObjectBase; - ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; onnxruntime::TensorShape shape; - static OrtTensorTypeAndShapeInfo* Create() { - return new OrtTensorTypeAndShapeInfo(); - } - + OrtTensorTypeAndShapeInfo() = default; OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; - - private: - OrtTensorTypeAndShapeInfo() = default; - ~OrtTensorTypeAndShapeInfo() { - assert(ref_count == 0); - } }; #define API_IMPL_BEGIN try { -#define API_IMPL_END \ - } \ - catch (std::exception & ex) { \ +#define API_IMPL_END \ + } \ + catch (std::exception & ex) { \ return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ } ORT_API(OrtTensorTypeAndShapeInfo*, OrtCreateTensorTypeAndShapeInfo) { - return OrtTensorTypeAndShapeInfo::Create(); + return new OrtTensorTypeAndShapeInfo(); +} + +ORT_API(void, OrtReleaseTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo* ptr) { + delete ptr; } ORT_API_STATUS_IMPL(OrtSetTensorElementType, _In_ OrtTensorTypeAndShapeInfo* this_ptr, enum ONNXTensorElementDataType type) { @@ -126,13 +119,13 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape, const on OrtTensorTypeAndShapeInfo* ret = OrtCreateTensorTypeAndShapeInfo(); auto status = OrtSetTensorElementType(ret, type); if (status != nullptr) { - OrtReleaseObject(ret); + OrtReleaseTensorTypeAndShapeInfo(ret); return status; } if (shape != nullptr) { status = OrtSetDims(ret, shape->GetDims().data(), shape->GetDims().size()); if (status != nullptr) { - OrtReleaseObject(ret); + OrtReleaseTensorTypeAndShapeInfo(ret); return status; } } @@ -160,7 +153,7 @@ ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value) { 
return ONNX_TYPE_UNKNOWN; } ONNXType ret = out->type; - OrtReleaseObject(out); + OrtReleaseTypeInfo(out); return ret; } catch (std::exception&) { return ONNX_TYPE_UNKNOWN; @@ -170,7 +163,7 @@ ORT_API(enum ONNXType, OrtGetValueType, _In_ const OrtValue* value) { /** * Get the type information of an OrtValue * \param value - * \return The returned value should be freed by OrtReleaseObject after use + * \return The returned value should be freed by OrtReleaseTypeInfo after use */ ORT_API_STATUS_IMPL(OrtGetTypeInfo, _In_ const OrtValue* value, struct OrtTypeInfo** out) { auto v = reinterpret_cast(value); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_factory.cc b/onnxruntime/core/providers/cpu/cpu_provider_factory.cc index dc8f9532bd..be905a8531 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_factory.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_factory.cc @@ -4,52 +4,33 @@ #include "core/providers/cpu/cpu_provider_factory.h" #include #include "cpu_execution_provider.h" +#include "core/session/abi_session_options_impl.h" -using namespace onnxruntime; +namespace onnxruntime { -namespace { -struct CpuProviderFactory { - const OrtProviderFactoryInterface* const cls; - std::atomic_int ref_count; - bool create_arena; - CpuProviderFactory(); +struct CpuProviderFactory : IExecutionProviderFactory { + CpuProviderFactory(bool create_arena) : create_arena_(create_arena) {} + ~CpuProviderFactory() override {} + std::unique_ptr CreateProvider() override; + + private: + bool create_arena_; }; -OrtStatus* ORT_API_CALL CreateCpu(void* this_, OrtProvider** out) { +std::unique_ptr CpuProviderFactory::CreateProvider() { CPUExecutionProviderInfo info; - CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_; - info.create_arena = this_ptr->create_arena; - CPUExecutionProvider* ret = new CPUExecutionProvider(info); - *out = (OrtProvider*)ret; - return nullptr; + info.create_arena = create_arena_; + return std::make_unique(info); } -uint32_t ORT_API_CALL 
ReleaseCpu(void* this_) { - CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_; - if (--this_ptr->ref_count == 0) - delete this_ptr; - return 0; +std::shared_ptr CreateExecutionProviderFactory_CPU(int use_arena) { + return std::make_shared(use_arena != 0); } -uint32_t ORT_API_CALL AddRefCpu(void* this_) { - CpuProviderFactory* this_ptr = (CpuProviderFactory*)this_; - ++this_ptr->ref_count; - return 0; -} +} // namespace onnxruntime -constexpr OrtProviderFactoryInterface cpu_cls = { - {AddRefCpu, - ReleaseCpu}, - CreateCpu, -}; - -CpuProviderFactory::CpuProviderFactory() : cls(&cpu_cls), ref_count(1), create_arena(true) {} -} // namespace - -ORT_API_STATUS_IMPL(OrtCreateCpuExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) { - CpuProviderFactory* ret = new CpuProviderFactory(); - ret->create_arena = (use_arena != 0); - *out = (OrtProviderFactoryInterface**)ret; +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) { + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CPU(use_arena)); return nullptr; } diff --git a/onnxruntime/core/providers/cpu/symbols.txt b/onnxruntime/core/providers/cpu/symbols.txt index 1940e8d441..fd64791936 100644 --- a/onnxruntime/core/providers/cpu/symbols.txt +++ b/onnxruntime/core/providers/cpu/symbols.txt @@ -1,4 +1,3 @@ -OrtAddRefToObject OrtAllocatorAlloc OrtAllocatorFree OrtAllocatorGetInfo @@ -12,7 +11,6 @@ OrtCloneSessionOptions OrtCompareAllocatorInfo OrtCreateAllocatorInfo OrtCreateCpuAllocatorInfo -OrtCreateCpuExecutionProviderFactory OrtCreateDefaultAllocator OrtCreateRunOptions OrtCreateSession @@ -47,9 +45,12 @@ OrtIsTensor OrtReleaseAllocator OrtReleaseAllocatorInfo OrtReleaseEnv -OrtReleaseObject +OrtReleaseRunOptions OrtReleaseSession +OrtReleaseSessionOptions OrtReleaseStatus +OrtReleaseTensorTypeAndShapeInfo +OrtReleaseTypeInfo OrtReleaseValue OrtRun OrtRunOptionsGetRunLogVerbosityLevel @@ 
-63,7 +64,7 @@ OrtSessionGetInputTypeInfo OrtSessionGetOutputCount OrtSessionGetOutputName OrtSessionGetOutputTypeInfo -OrtSessionOptionsAppendExecutionProvider +OrtSessionOptionsAppendExecutionProvider_CPU OrtSetDims OrtSetSessionLogId OrtSetSessionLogVerbosityLevel diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index fb419ac922..3dbf8ea83d 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -4,51 +4,35 @@ #include "core/providers/cuda/cuda_provider_factory.h" #include #include "cuda_execution_provider.h" +#include "core/session/abi_session_options_impl.h" using namespace onnxruntime; -namespace { -struct CUDAProviderFactory { - const OrtProviderFactoryInterface* const cls; - std::atomic_int ref_count; - int device_id; - CUDAProviderFactory(); +namespace onnxruntime { + +struct CUDAProviderFactory : IExecutionProviderFactory { + CUDAProviderFactory(int device_id) : device_id_(device_id) {} + ~CUDAProviderFactory() override {} + + std::unique_ptr CreateProvider() override; + + private: + int device_id_; }; -OrtStatus* ORT_API_CALL CreateCuda(void* this_, OrtProvider** out) { +std::unique_ptr CUDAProviderFactory::CreateProvider() { CUDAExecutionProviderInfo info; - CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_; - info.device_id = this_ptr->device_id; - CUDAExecutionProvider* ret = new CUDAExecutionProvider(info); - *out = (OrtProvider*)ret; - return nullptr; -} - -uint32_t ORT_API_CALL ReleaseCuda(void* this_) { - CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_; - if (--this_ptr->ref_count == 0) - delete this_ptr; - return 0; -} - -uint32_t ORT_API_CALL AddRefCuda(void* this_) { - CUDAProviderFactory* this_ptr = (CUDAProviderFactory*)this_; - ++this_ptr->ref_count; - return 0; -} - -constexpr OrtProviderFactoryInterface cuda_cls = { - AddRefCuda, - ReleaseCuda, - CreateCuda, 
-}; - -CUDAProviderFactory::CUDAProviderFactory() : cls(&cuda_cls), ref_count(1), device_id(0) {} -} // namespace - -ORT_API_STATUS_IMPL(OrtCreateCUDAExecutionProviderFactory, int device_id, _Out_ OrtProviderFactoryInterface*** out) { - CUDAProviderFactory* ret = new CUDAProviderFactory(); - ret->device_id = device_id; - *out = (OrtProviderFactoryInterface**)ret; + info.device_id = device_id_; + return std::make_unique(info); +} + +std::shared_ptr CreateExecutionProviderFactory_CUDA(int device_id) { + return std::make_shared(device_id); +} + +} // namespace onnxruntime + +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id) { + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_CUDA(device_id)); return nullptr; } diff --git a/onnxruntime/core/providers/cuda/symbols.txt b/onnxruntime/core/providers/cuda/symbols.txt index 30d625edc0..157b3c3ae5 100644 --- a/onnxruntime/core/providers/cuda/symbols.txt +++ b/onnxruntime/core/providers/cuda/symbols.txt @@ -1 +1 @@ -OrtCreateCUDAExecutionProviderFactory +OrtSessionOptionsAppendExecutionProvider_CUDA \ No newline at end of file diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc b/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc index f9a7b0f063..c94060d5b2 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.cc @@ -4,51 +4,34 @@ #include "core/providers/mkldnn/mkldnn_provider_factory.h" #include #include "mkldnn_execution_provider.h" +#include "core/session/abi_session_options_impl.h" using namespace onnxruntime; -namespace { -struct MkldnnProviderFactory { - const OrtProviderFactoryInterface* const cls; - std::atomic_int ref_count; - bool create_arena; - MkldnnProviderFactory(); +namespace onnxruntime { +struct MkldnnProviderFactory : IExecutionProviderFactory { + MkldnnProviderFactory(bool create_arena) : 
create_arena_(create_arena) {} + ~MkldnnProviderFactory() override {} + + std::unique_ptr CreateProvider() override; + + private: + bool create_arena_; }; -OrtStatus* ORT_API_CALL CreateMkldnn(void* this_, OrtProvider** out) { +std::unique_ptr MkldnnProviderFactory::CreateProvider() { MKLDNNExecutionProviderInfo info; - MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_; - info.create_arena = this_ptr->create_arena; - MKLDNNExecutionProvider* ret = new MKLDNNExecutionProvider(info); - *out = (OrtProvider*)ret; - return nullptr; -} - -uint32_t ORT_API_CALL ReleaseMkldnn(void* this_) { - MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_; - if (--this_ptr->ref_count == 0) - delete this_ptr; - return 0; -} - -uint32_t ORT_API_CALL AddRefMkldnn(void* this_) { - MkldnnProviderFactory* this_ptr = (MkldnnProviderFactory*)this_; - ++this_ptr->ref_count; - return 0; -} - -constexpr OrtProviderFactoryInterface mkl_cls = { - {AddRefMkldnn, - ReleaseMkldnn}, - CreateMkldnn, -}; - -MkldnnProviderFactory::MkldnnProviderFactory() : cls(&mkl_cls), ref_count(1), create_arena(true) {} -} // namespace - -ORT_API_STATUS_IMPL(OrtCreateMkldnnExecutionProviderFactory, int use_arena, _Out_ OrtProviderFactoryInterface*** out) { - MkldnnProviderFactory* ret = new MkldnnProviderFactory(); - ret->create_arena = (use_arena != 0); - *out = (OrtProviderFactoryInterface**)ret; + info.create_arena = create_arena_; + return std::make_unique(info); +} + +std::shared_ptr CreateExecutionProviderFactory_Mkldnn(int device_id) { + return std::make_shared(device_id); +} + +} // namespace onnxruntime + +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena) { + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Mkldnn(use_arena)); return nullptr; } diff --git a/onnxruntime/core/providers/mkldnn/symbols.txt b/onnxruntime/core/providers/mkldnn/symbols.txt index 4cc61114a7..a4ded157c4 100644 
--- a/onnxruntime/core/providers/mkldnn/symbols.txt +++ b/onnxruntime/core/providers/mkldnn/symbols.txt @@ -1 +1 @@ -OrtCreateMkldnnExecutionProviderFactory +OrtSessionOptionsAppendExecutionProvider_Mkldnn \ No newline at end of file diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 0b30c7fcb7..43805d9615 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -8,10 +8,6 @@ #include "abi_session_options_impl.h" OrtSessionOptions::~OrtSessionOptions() { - assert(ref_count == 0); - for (OrtProviderFactoryInterface** p : provider_factories) { - OrtReleaseObject(p); - } } OrtSessionOptions& OrtSessionOptions::operator=(const OrtSessionOptions&) { @@ -19,15 +15,17 @@ OrtSessionOptions& OrtSessionOptions::operator=(const OrtSessionOptions&) { } OrtSessionOptions::OrtSessionOptions(const OrtSessionOptions& other) : value(other.value), custom_op_paths(other.custom_op_paths), provider_factories(other.provider_factories) { - for (OrtProviderFactoryInterface** p : other.provider_factories) { - OrtAddRefToObject(p); - } } + ORT_API(OrtSessionOptions*, OrtCreateSessionOptions) { std::unique_ptr options = std::make_unique(); return options.release(); } +ORT_API(void, OrtReleaseSessionOptions, OrtSessionOptions* ptr) { + delete ptr; +} + ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) { try { return new OrtSessionOptions(*input); @@ -36,11 +34,6 @@ ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) { } } -ORT_API(void, OrtSessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ OrtProviderFactoryInterface** f) { - OrtAddRefToObject(f); - options->provider_factories.push_back(f); -} - ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) { options->value.enable_sequential_execution = true; } diff --git a/onnxruntime/core/session/abi_session_options_impl.h 
b/onnxruntime/core/session/abi_session_options_impl.h index 3d1aa52f01..1af9e5f268 100644 --- a/onnxruntime/core/session/abi_session_options_impl.h +++ b/onnxruntime/core/session/abi_session_options_impl.h @@ -6,14 +6,14 @@ #include #include #include -#include "core/framework/onnx_object_cxx.h" #include "core/session/inference_session.h" #include "core/session/onnxruntime_c_api.h" +#include "core/providers/providers.h" -struct OrtSessionOptions : public onnxruntime::ObjectBase { +struct OrtSessionOptions { onnxruntime::SessionOptions value; std::vector custom_op_paths; - std::vector provider_factories; + std::vector> provider_factories; OrtSessionOptions() = default; ~OrtSessionOptions(); OrtSessionOptions(const OrtSessionOptions& other); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 05911b00d1..0544c004a0 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -20,7 +20,6 @@ #include "core/framework/environment.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" -#include "core/framework/onnx_object_cxx.h" #include "core/session/inference_session.h" #include "abi_session_options_impl.h" @@ -49,7 +48,6 @@ struct OrtEnv { public: Environment* value; LoggingManager* loggingManager; - friend class onnxruntime::ObjectBase; OrtEnv(Environment* value1, LoggingManager* loggingManager1) : value(value1), loggingManager(loggingManager1) { } @@ -367,13 +365,10 @@ static OrtStatus* CreateSessionImpl(_In_ OrtEnv* env, _In_ T model_path, return ToOrtStatus(status); } if (options != nullptr) - for (OrtProviderFactoryInterface** p : options->provider_factories) { - OrtProvider* provider; - OrtStatus* error_code = (*p)->CreateProvider(p, &provider); - if (error_code) - return error_code; - sess->RegisterExecutionProvider(std::unique_ptr( - reinterpret_cast(provider))); + for (auto& factory : 
options->provider_factories) { + auto provider = factory->CreateProvider(); + if (provider) + sess->RegisterExecutionProvider(std::move(provider)); } status = sess->Load(model_path); if (!status.IsOK()) @@ -638,5 +633,6 @@ ORT_API_STATUS_IMPL(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Env, OrtEnv) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, MLValue) +DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) DEFINE_RELEASE_ORT_OBJECT_FUNCTION_FOR_ARRAY(Status, char) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 1db3d36f2d..91f5716c0b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -42,6 +42,7 @@ #define BACKEND_DEVICE BACKEND_PROC BACKEND_MKLDNN BACKEND_MKLML BACKEND_OPENBLAS #include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/providers.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/providers/cpu/cpu_provider_factory.h" @@ -54,6 +55,15 @@ #ifdef USE_NUPHAR #include "core/providers/nuphar/nuphar_provider_factory.h" #endif + +namespace onnxruntime { +std::shared_ptr CreateExecutionProviderFactory_CPU(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_CUDA(int device_id); +std::shared_ptr CreateExecutionProviderFactory_Mkldnn(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_Nuphar(int device_id, const char*); +std::shared_ptr CreateExecutionProviderFactory_BrainSlice(int id, bool f, const char*, const char*, const char*); +} // namespace onnxruntime + #if defined(_MSC_VER) #pragma warning(disable : 4267 4996 4503 4003) #endif // _MSC_VER @@ -172,45 +182,32 @@ class SessionObjectInitializer { } }; -inline void RegisterExecutionProvider(InferenceSession* sess, OrtProviderFactoryInterface** f) { - OrtProvider* p; - 
(*f)->CreateProvider(f, &p); - std::unique_ptr q((onnxruntime::IExecutionProvider*)p); - auto status = sess->RegisterExecutionProvider(std::move(q)); +inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExecutionProviderFactory& f) { + auto p = f.CreateProvider(); + auto status = sess->RegisterExecutionProvider(std::move(p)); if (!status.IsOK()) { throw std::runtime_error(status.ErrorMessage().c_str()); } } -#define FACTORY_PTR_HOLDER \ - std::unique_ptr ptr_holder_(f, OrtReleaseObject); - void InitializeSession(InferenceSession* sess) { onnxruntime::common::Status status; + #ifdef USE_CUDA { - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f)); - RegisterExecutionProvider(sess, f); - FACTORY_PTR_HOLDER; + RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_CUDA(0)); } #endif #ifdef USE_MKLDNN { const bool enable_cpu_mem_arena = true; - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena ? 1 : 0, &f)); - RegisterExecutionProvider(sess, f); - FACTORY_PTR_HOLDER; + RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Mkldnn(enable_cpu_mem_arena ? 
1 : 0)); } #endif #if 0 //USE_NUPHAR { - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f)); - RegisterExecutionProvider(sess, f); - FACTORY_PTR_HOLDER; + RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nuphar(0, "")); } #endif diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 03ef85115a..bf1c2ec71b 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -203,10 +203,7 @@ int real_main(int argc, char* argv[]) { sf.DisableSequentialExecution(); if (enable_cuda) { #ifdef USE_CUDA - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f)); - sf.AppendExecutionProvider(f); - OrtReleaseObject(f); + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0)); #else fprintf(stderr, "CUDA is not supported in this build"); return -1; @@ -214,10 +211,7 @@ int real_main(int argc, char* argv[]) { } if (enable_nuphar) { #ifdef USE_NUPHAR - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f)); - sf.AppendExecutionProvider(f); - OrtReleaseObject(f); + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, "")); #else fprintf(stderr, "Nuphar is not supported in this build"); return -1; @@ -225,10 +219,7 @@ int real_main(int argc, char* argv[]) { } if (enable_mkl) { #ifdef USE_MKLDNN - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena ? 1 : 0, &f)); - sf.AppendExecutionProvider(f); - OrtReleaseObject(f); + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, enable_cpu_mem_arena ? 
1 : 0)); #else fprintf(stderr, "MKL-DNN is not supported in this build"); return -1; diff --git a/onnxruntime/test/perftest/testenv.cc b/onnxruntime/test/perftest/testenv.cc index fbfbb04e12..68a843c385 100644 --- a/onnxruntime/test/perftest/testenv.cc +++ b/onnxruntime/test/perftest/testenv.cc @@ -12,21 +12,17 @@ #include #endif #include "providers.h" +#include "default_providers.h" using namespace std::experimental::filesystem::v1; using onnxruntime::Status; -inline void RegisterExecutionProvider(onnxruntime::InferenceSession* sess, OrtProviderFactoryInterface** f) { - OrtProvider* p; - (*f)->CreateProvider(f, &p); - std::unique_ptr q((onnxruntime::IExecutionProvider*)p); - auto status = sess->RegisterExecutionProvider(std::move(q)); +inline void RegisterExecutionProvider(onnxruntime::InferenceSession* sess, std::unique_ptr&& f) { + auto status = sess->RegisterExecutionProvider(std::move(f)); if (!status.IsOK()) { throw std::runtime_error(status.ErrorMessage().c_str()); } } -#define FACTORY_PTR_HOLDER \ - std::unique_ptr ptr_holder_(f, OrtReleaseObject); Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>& sess, const path& model_url, const std::string& logid) const { ::onnxruntime::SessionOptions so; @@ -41,37 +37,25 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>& for (const std::string& provider : providers_) { if (provider == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f)); - FACTORY_PTR_HOLDER; - RegisterExecutionProvider(sess.get(), f); + RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultCudaExecutionProvider()); #else ORT_THROW("CUDA is not supported in this build"); #endif } else if (provider == onnxruntime::kMklDnnExecutionProvider) { #ifdef USE_MKLDNN - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_cpu_mem_arena_ ? 
1 : 0, &f)); - FACTORY_PTR_HOLDER; - RegisterExecutionProvider(sess.get(), f); + RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultMkldnnExecutionProvider(enable_cpu_mem_arena_ ? 1 : 0)); #else ORT_THROW("CUDA is not supported in this build"); #endif } else if (provider == onnxruntime::kNupharExecutionProvider) { #ifdef USE_NUPHAR - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f)); - RegisterExecutionProvider(sess.get(), f); - FACTORY_PTR_HOLDER; + RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultNupharExecutionProvider()); #else ORT_THROW("CUDA is not supported in this build"); #endif } else if (provider == onnxruntime::kBrainSliceExecutionProvider) { #if USE_BRAINSLICE - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateBrainSliceExecutionProviderFactory(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f)); - RegisterExecutionProvider(sess.get(), f); - FACTORY_PTR_HOLDER; + RegisterExecutionProvider(sess.get(), onnxruntime::test::DefaultBrainsliceExecutionProvider()); #else ORT_THROW("This executable was not built with BrainSlice"); #endif @@ -85,7 +69,7 @@ Status SessionFactory::create(std::shared_ptr<::onnxruntime::InferenceSession>& ORT_THROW("TensorRT is not supported in this build"); #endif } - //TODO: add more + // TODO: add more } status = sess->Load(model_url.string()); diff --git a/onnxruntime/test/shared_lib/fns_candy_style_transfer.c b/onnxruntime/test/shared_lib/fns_candy_style_transfer.c index 86a37bf310..f266464bae 100644 --- a/onnxruntime/test/shared_lib/fns_candy_style_transfer.c +++ b/onnxruntime/test/shared_lib/fns_candy_style_transfer.c @@ -182,10 +182,7 @@ void verify_input_output_count(OrtSession* session) { #ifdef USE_CUDA void enable_cuda(OrtSessionOptions* session_option) { - OrtProviderFactoryInterface** factory; - 
ORT_ABORT_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &factory)); - OrtSessionOptionsAppendExecutionProvider(session_option, factory); - OrtReleaseObject(factory); + ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); } #endif @@ -207,9 +204,9 @@ int main(int argc, char* argv[]) { ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_option, &session)); verify_input_output_count(session); int ret = run_inference(session, input_file, output_file); - OrtReleaseObject(session_option); + OrtReleaseSessionOptions(session_option); OrtReleaseSession(session); - OrtReleaseObject(env); + OrtReleaseEnv(env); if (ret != 0) { fprintf(stderr, "fail\n"); } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e623d11de9..e74f980420 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -64,30 +64,21 @@ void TestInference(OrtEnv* env, T model_uri, if (provider_type == 1) { #ifdef USE_CUDA - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f)); - sf.AppendExecutionProvider(f); - OrtReleaseObject(f); + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(sf, 0)); std::cout << "Running simple inference with cuda provider" << std::endl; #else return; #endif } else if (provider_type == 2) { #ifdef USE_MKLDNN - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(1, &f)); - sf.AppendExecutionProvider(f); - OrtReleaseObject(f); + ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(sf, 1)); std::cout << "Running simple inference with mkldnn provider" << std::endl; #else return; #endif } else if (provider_type == 3) { #ifdef USE_NUPHAR - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f)); - sf.AppendExecutionProvider(f); - OrtReleaseObject(f); + 
ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Nuphar(sf, 0, "")); std::cout << "Running simple inference with nuphar provider" << std::endl; #else return; @@ -196,7 +187,7 @@ TEST_F(CApiTest, create_tensor_with_data) { const struct OrtTensorTypeAndShapeInfo* tensor_info = OrtCastTypeInfoToTensorInfo(type_info); ASSERT_NE(tensor_info, nullptr); ASSERT_EQ(1, OrtGetNumOfDimensions(tensor_info)); - OrtReleaseObject(type_info); + OrtReleaseTypeInfo(type_info); } int main(int argc, char** argv) { diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 9d08cd8ed4..89d5866d3d 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -4,28 +4,25 @@ #include "default_providers.h" #include "providers.h" #include "core/session/onnxruntime_cxx_api.h" -#define FACTORY_PTR_HOLDER \ - std::unique_ptr ptr_holder_(f); +#include "core/providers/providers.h" namespace onnxruntime { + +std::shared_ptr CreateExecutionProviderFactory_CPU(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_CUDA(int device_id); +std::shared_ptr CreateExecutionProviderFactory_Mkldnn(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_Nuphar(int device_id, const char*); +std::shared_ptr CreateExecutionProviderFactory_BrainSlice(int id, bool f, const char*, const char*, const char*); + namespace test { + std::unique_ptr DefaultCpuExecutionProvider(bool enable_arena) { - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateCpuExecutionProviderFactory(enable_arena ? 
1 : 0, &f)); - FACTORY_PTR_HOLDER; - OrtProvider* out; - ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out)); - return std::unique_ptr((IExecutionProvider*)out); + return CreateExecutionProviderFactory_CPU(enable_arena)->CreateProvider(); } std::unique_ptr DefaultCudaExecutionProvider() { #ifdef USE_CUDA - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateCUDAExecutionProviderFactory(0, &f)); - FACTORY_PTR_HOLDER; - OrtProvider* out; - ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out)); - return std::unique_ptr((IExecutionProvider*)out); + return CreateExecutionProviderFactory_CUDA(0)->CreateProvider(); #else return nullptr; #endif @@ -33,12 +30,7 @@ std::unique_ptr DefaultCudaExecutionProvider() { std::unique_ptr DefaultMkldnnExecutionProvider(bool enable_arena) { #ifdef USE_MKLDNN - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateMkldnnExecutionProviderFactory(enable_arena ? 1 : 0, &f)); - FACTORY_PTR_HOLDER; - OrtProvider* out; - ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out)); - return std::unique_ptr((IExecutionProvider*)out); + return CreateExecutionProviderFactory_Mkldnn(enable_arena ? 
1 : 0)->CreateProvider(); #else ORT_UNUSED_PARAMETER(enable_arena); return nullptr; @@ -47,12 +39,7 @@ std::unique_ptr DefaultMkldnnExecutionProvider(bool enable_a std::unique_ptr DefaultNupharExecutionProvider() { #ifdef USE_NUPHAR - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateNupharExecutionProviderFactory(0, "", &f)); - FACTORY_PTR_HOLDER; - OrtProvider* out; - ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out)); - return std::unique_ptr((IExecutionProvider*)out); + return CreateExecutionProviderFactory_Nuphar(0, "")->CreateProvider(); #else return nullptr; #endif @@ -60,12 +47,7 @@ std::unique_ptr DefaultNupharExecutionProvider() { std::unique_ptr DefaultBrainSliceExecutionProvider() { #ifdef USE_BRAINSLICE - OrtProviderFactoryInterface** f; - ORT_THROW_ON_ERROR(OrtCreateBrainSliceExecutionProviderFactory(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin", &f)); - FACTORY_PTR_HOLDER; - OrtProvider* out; - ORT_THROW_ON_ERROR((*f)->CreateProvider(f, &out)); - return std::unique_ptr((IExecutionProvider*)out); + return CreateExecutionProviderFactory_BrainSlice(0, true, "testdata/firmwares/onnx_rnns/instructions.bin", "testdata/firmwares/onnx_rnns/data.bin", "testdata/firmwares/onnx_rnns/schema.bin")->CreateProvider(); #else return nullptr; #endif