diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 4b9562c30f..d3481d3287 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -275,6 +275,7 @@ namespace Microsoft.ML.OnnxRuntime OrtAllocatorFree = (DOrtAllocatorFree)Marshal.GetDelegateForFunctionPointer(api_.AllocatorFree, typeof(DOrtAllocatorFree)); OrtAllocatorGetInfo = (DOrtAllocatorGetInfo)Marshal.GetDelegateForFunctionPointer(api_.AllocatorGetInfo, typeof(DOrtAllocatorGetInfo)); OrtAddFreeDimensionOverride = (DOrtAddFreeDimensionOverride)Marshal.GetDelegateForFunctionPointer(api_.AddFreeDimensionOverride, typeof(DOrtAddFreeDimensionOverride)); + OrtAddFreeDimensionOverrideByName = (DOrtAddFreeDimensionOverrideByName)Marshal.GetDelegateForFunctionPointer(api_.AddFreeDimensionOverrideByName, typeof(DOrtAddFreeDimensionOverrideByName)); OrtCreateIoBinding = (DOrtCreateIoBinding)Marshal.GetDelegateForFunctionPointer(api_.CreateIoBinding, typeof(DOrtCreateIoBinding)); OrtReleaseIoBinding = (DOrtReleaseIoBinding)Marshal.GetDelegateForFunctionPointer(api_.ReleaseIoBinding, typeof(DOrtReleaseIoBinding)); @@ -321,6 +322,8 @@ namespace Microsoft.ML.OnnxRuntime OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap)); OrtReleaseModelMetadata = (DOrtReleaseModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.ReleaseModelMetadata, typeof(DOrtReleaseModelMetadata)); + OrtGetAvailableProviders = (DOrtGetAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.GetAvailableProviders, typeof(DOrtGetAvailableProviders)); + OrtReleaseAvailableProviders = (DOrtReleaseAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAvailableProviders, typeof(DOrtReleaseAvailableProviders)); } [DllImport(nativeLib, CharSet = charSet)] @@ -522,7 +525,15 @@ namespace Microsoft.ML.OnnxRuntime public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, GraphOptimizationLevel graphOptimizationLevel); public static DOrtSetSessionGraphOptimizationLevel OrtSetSessionGraphOptimizationLevel; - public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options, string configKey, string configValue); + /// + /// Add session config entry + /// + /// Native SessionOptions instance + /// Config key + /// Config value + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options, + IntPtr /* const char* */configKey, + IntPtr /* const char* */ configValue); public static DOrtAddSessionConfigEntry OrtAddSessionConfigEntry; ///** @@ -564,13 +575,49 @@ namespace Microsoft.ML.OnnxRuntime //[DllImport(nativeLib, CharSet = charSet)] //public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path); - public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ symbolic_dim, int dim_override); + /// + /// Free Dimension override (by denotation) + /// + /// Native SessionOptions instance + /// Dimension denotation + /// Dimension value + public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const char*)*/ dimDenotation, + long dimValue); public static DOrtAddFreeDimensionOverride OrtAddFreeDimensionOverride; - public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ library_path, out IntPtr /* (void**) */ library_handle); + /// + /// Free Dimension override (by name) + /// + /// Native SessionOptions instance + /// Dimension name + /// Dimension value + public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const char*)*/ dimName, + long dimValue); + public static DOrtAddFreeDimensionOverrideByName OrtAddFreeDimensionOverrideByName; + + + /// + /// Register custom op library + /// + /// Native SessionOptions instance + /// Library path + /// (out) Native library handle + public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, + IntPtr /*(const char*)*/ libraryPath, + out IntPtr /*(void**)*/ libraryHandle); public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary; - public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ name, IntPtr /* OrtValue* */ ort_value); + /// + /// Add initializer that is shared across Sessions using this SessionOptions (by denotation) + /// + /// Native SessionOptions instance + /// Name of the initializer + /// Native OrtValue instnce + public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const char*)*/ name, + IntPtr /*(OrtValue*)*/ ortValue); public static DOrtAddInitializer OrtAddInitializer; #endregion @@ -1048,6 +1095,27 @@ namespace Microsoft.ML.OnnxRuntime #endregion + #region Misc API + + /// + /// Queries all the execution providers supported in the native onnxruntime shared library + /// + /// (output) all execution providers (strings) supported in the native onnxruntime shared library + /// (output) number of execution providers (strings) + + public delegate IntPtr /* (OrtStatus*) */ DOrtGetAvailableProviders(out IntPtr /* (char***) */ providers, out int /* (int*) */ numProviders); + public static DOrtGetAvailableProviders OrtGetAvailableProviders; + + /// + /// Releases all execution provider strings allocated and returned by OrtGetAvailableProviders + /// + /// all execution providers (strings) returned by OrtGetAvailableProviders + /// number of execution providers (strings) + + public delegate IntPtr /* (OrtStatus*) */ DOrtReleaseAvailableProviders(IntPtr /* (char**) */ providers, int /* (int) */ numProviders); + public static DOrtReleaseAvailableProviders OrtReleaseAvailableProviders; + #endregion + public static byte[] GetPlatformSerializedString(string str) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs index 209484d8ab..f2c58a7776 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs @@ -101,6 +101,36 @@ namespace Microsoft.ML.OnnxRuntime NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableTelemetryEvents(Handle)); } + /// + /// Queries all the execution providers supported in the native onnxruntime shared library + /// + public string[] GetAvailableProviders() + { + IntPtr availableProvidersHandle = IntPtr.Zero; + int numProviders; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetAvailableProviders(out availableProvidersHandle, out numProviders)); + + var availableProviders = new string[numProviders]; + + try + { + for(int i=0; i /// Tensor object - /// For all tensor types but string tensors we endevour to use managed memory + /// For all tensor types but string tensors we endeavor to use managed memory /// to avoid additional allocation and copy. This out parameter represents a chunk of pinned memory /// /// discovered tensor element type diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index f2aead52cc..3b9cec2580 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -170,10 +170,41 @@ namespace Microsoft.ML.OnnxRuntime #endregion //ExecutionProviderAppends #region Public Methods + + /// + /// (Deprecated) Loads a DLL named 'libraryPath' and looks for this entry point: + /// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); + /// It then passes in the provided session options to this function along with the api base. + /// Deprecated in favor of RegisterCustomOpLibraryV2() because it provides users with the library handle + /// to release when all sessions relying on it are destroyed + /// + [ObsoleteAttribute("RegisterCustomOpLibrary(...) is obsolete. Use RegisterCustomOpLibraryV2(...) instead.", false)] public void RegisterCustomOpLibrary(string libraryPath) { IntPtr libraryHandle = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, libraryPath, out libraryHandle)); + var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned); + using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned)) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle)); + } + } + + /// + /// Loads a DLL named 'libraryPath' and looks for this entry point: + /// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); + /// It then passes in the provided session options to this function along with the api base. + /// The handle to the loaded library is returned in 'libraryHandle'. + /// It can be unloaded by the caller after all sessions using the passed in + /// session options are destroyed, or if an error occurs and it is non null. + /// Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle + /// + public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHandle) + { + var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned); + using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned)) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle)); + } } /// @@ -186,18 +217,60 @@ namespace Microsoft.ML.OnnxRuntime /// managed by the user (created using the CreateTensorWithDataAsOrtValue API) and it must outlive the session object /// to which it is added. /// - public void AddInitializer(string name, OrtValue ort_value) + public void AddInitializer(string name, OrtValue ortValue) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, name, ort_value.Handle)); + var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned); + using (var pinnedName = new PinnedGCHandle(utf8NamePinned)) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, pinnedName.Pointer, ortValue.Handle)); + } } + /// + /// Set a single session configuration entry as a pair of strings + /// If a configuration with same key exists, this will overwrite the configuration with the given configValue + /// public void AddSessionConfigEntry(string configKey, string configValue) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle, configKey, configValue)); - } - #endregion + var utf8NameConfigKeyPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configKey), GCHandleType.Pinned); + var utf8NameConfigValuePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configValue), GCHandleType.Pinned); - internal IntPtr Handle + using (var pinnedConfigKeyName = new PinnedGCHandle(utf8NameConfigKeyPinned)) + using (var pinnedConfigValueName = new PinnedGCHandle(utf8NameConfigValuePinned)) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle, + pinnedConfigKeyName.Pointer, pinnedConfigValueName.Pointer)); + } + } + + /// + /// Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable + /// optimizations that can take advantage of fixed values (such as memory planning, etc) + /// + public void AddFreeDimensionOverride(string dimDenotation, long dimValue) + { + var utf8DimDenotationPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimDenotation), GCHandleType.Pinned); + using (var pinnedDimDenotation = new PinnedGCHandle(utf8DimDenotationPinned)) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverride(handle, pinnedDimDenotation.Pointer, dimValue)); + } + } + + /// + /// Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable + /// optimizations that can take advantage of fixed values (such as memory planning, etc) + /// + public void AddFreeDimensionOverrideByName(string dimName, long dimValue) + { + var utf8DimNamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName), GCHandleType.Pinned); + using (var pinnedDimName = new PinnedGCHandle(utf8DimNamePinned)) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, pinnedDimName.Pointer, dimValue)); + } + } + #endregion + + internal IntPtr Handle { get { diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index c41e650f37..1dcf4a092e 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -176,6 +176,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests ortEnvInstance.EnableTelemetryEvents(); } + [Fact] + public void GetAvailableProviders() + { + var ortEnvInstance = OrtEnv.Instance(); + string[] providers = ortEnvInstance.GetAvailableProviders(); + + Assert.True(providers.Length > 0); + Assert.Equal("CPUExecutionProvider", providers[0]); + +# if USE_CUDA + Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider");); +#endif + } + [Fact] public void CanCreateAndDisposeSessionWithModelPath() { @@ -409,6 +423,48 @@ namespace Microsoft.ML.OnnxRuntime.Tests Assert.True(startTime1 <= startTime2 && startTime2 <= startTime3); } + [Fact] + public void SessionOptionsFreeDimensionOverrides() + { + + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "abs_free_dimensions.onnx"); + + // By Name + using (SessionOptions options = new SessionOptions()) + { + options.AddFreeDimensionOverrideByName("Dim1", 4); + options.AddFreeDimensionOverrideByName("Dim2", 6); + + using (var session = new InferenceSession(modelPath, options)) + { + var inputMetadata = session.InputMetadata; + var dims = inputMetadata["x"].Dimensions; + Assert.Equal(3, dims.Length); + Assert.Equal(4, dims[0]); + Assert.Equal(6, dims[1]); + Assert.Equal(5, dims[2]); + } + } + + // By Denotation + using (SessionOptions options = new SessionOptions()) + { + options.AddFreeDimensionOverride("DATA_BATCH", 3); + options.AddFreeDimensionOverride("DATA_CHANNEL", 5); + + using (var session = new InferenceSession(modelPath, options)) + { + var inputMetadata = session.InputMetadata; + var dims = inputMetadata["x"].Dimensions; + Assert.Equal(3, dims.Length); + Assert.Equal(3, dims[0]); + Assert.Equal(5, dims[1]); + Assert.Equal(5, dims[2]); + } + } + + } + private void validateRunResults(IReadOnlyCollection results) { // validate the results @@ -903,6 +959,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } + // Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle + private void UnloadLibrary(IntPtr libraryHandle) + { + if (libraryHandle != IntPtr.Zero) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if(!FreeLibrary(libraryHandle)) + { + throw new Exception("Could not unload the provided shared library using its handle"); + } + } + + else + { + // TODO: Deal with non-Windows platforms for the .NET Core use-case + } + } + } + [SkipNonPackageTests] private void TestRegisterCustomOpLibrary() { @@ -926,9 +1002,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName); Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + IntPtr libraryHandle = IntPtr.Zero; try { - option.RegisterCustomOpLibrary(libFullPath); + + option.RegisterCustomOpLibraryV2(libFullPath, out libraryHandle); } catch (Exception ex) { @@ -979,6 +1057,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests Assert.True(tensorOut.SequenceEqual(expectedOut)); } } + + // Safe to unload the custom op shared library now + UnloadLibrary(libraryHandle); } } @@ -1962,6 +2043,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests [DllImport("kernel32", CharSet = CharSet.Ansi)] static extern UIntPtr GetProcAddress(IntPtr hModule, string procName); + [DllImport("kernel32.dll", CharSet = CharSet.Ansi)] + private static extern bool FreeLibrary(IntPtr hModule); + [Fact] private void VerifyNativeMethodsExist() { @@ -1996,12 +2080,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests ,"OrtSessionOptionsAppendExecutionProvider_Nnapi" #endif }; - - var hModule = LoadLibrary(module); - foreach (var ep in entryPointNames) + IntPtr libraryHandle = IntPtr.Zero; + try { - var x = GetProcAddress(hModule, ep); - Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}"); + libraryHandle = LoadLibrary(module); + foreach (var ep in entryPointNames) + { + var x = GetProcAddress(libraryHandle, ep); + Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}"); + } + } + + finally + { + UnloadLibrary(libraryHandle); } } diff --git a/csharp/testdata/abs_free_dimensions.onnx b/csharp/testdata/abs_free_dimensions.onnx new file mode 100644 index 0000000000..4c3d5ab6c4 --- /dev/null +++ b/csharp/testdata/abs_free_dimensions.onnx @@ -0,0 +1,14 @@ + backend-test:s + +xy"Abstest_absZ9 +x4 +2. +Dim1 +DATA_BATCH +Dim2 DATA_CHANNEL +b +y + +Dim1 +Dim2 +B \ No newline at end of file diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 9aec6c0fe0..1cd852078d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "test_allocator.h" #include "test_fixture.h" @@ -1006,6 +1007,11 @@ TEST(CApiTest, get_available_providers_cpp) { std::vector providers = Ort::GetAvailableProviders(); ASSERT_TRUE(providers.size() > 0); ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider")); + +#ifdef USE_CUDA + // CUDA EP will exist in the list but its position may vary based on other EPs included in the build + ASSERT_TRUE(std::find(providers.begin(), providers.end(), std::string("CUDAExecutionProvider")) != providers.end()); +#endif } // This test uses the CreateAndRegisterAllocator API to register an allocator with the env,