diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 4b9562c30f..d3481d3287 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -275,6 +275,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtAllocatorFree = (DOrtAllocatorFree)Marshal.GetDelegateForFunctionPointer(api_.AllocatorFree, typeof(DOrtAllocatorFree));
OrtAllocatorGetInfo = (DOrtAllocatorGetInfo)Marshal.GetDelegateForFunctionPointer(api_.AllocatorGetInfo, typeof(DOrtAllocatorGetInfo));
OrtAddFreeDimensionOverride = (DOrtAddFreeDimensionOverride)Marshal.GetDelegateForFunctionPointer(api_.AddFreeDimensionOverride, typeof(DOrtAddFreeDimensionOverride));
+ OrtAddFreeDimensionOverrideByName = (DOrtAddFreeDimensionOverrideByName)Marshal.GetDelegateForFunctionPointer(api_.AddFreeDimensionOverrideByName, typeof(DOrtAddFreeDimensionOverrideByName));
OrtCreateIoBinding = (DOrtCreateIoBinding)Marshal.GetDelegateForFunctionPointer(api_.CreateIoBinding, typeof(DOrtCreateIoBinding));
OrtReleaseIoBinding = (DOrtReleaseIoBinding)Marshal.GetDelegateForFunctionPointer(api_.ReleaseIoBinding, typeof(DOrtReleaseIoBinding));
@@ -321,6 +322,8 @@ namespace Microsoft.ML.OnnxRuntime
OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap));
OrtReleaseModelMetadata = (DOrtReleaseModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.ReleaseModelMetadata, typeof(DOrtReleaseModelMetadata));
+ OrtGetAvailableProviders = (DOrtGetAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.GetAvailableProviders, typeof(DOrtGetAvailableProviders));
+ OrtReleaseAvailableProviders = (DOrtReleaseAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAvailableProviders, typeof(DOrtReleaseAvailableProviders));
}
[DllImport(nativeLib, CharSet = charSet)]
@@ -522,7 +525,15 @@ namespace Microsoft.ML.OnnxRuntime
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, GraphOptimizationLevel graphOptimizationLevel);
public static DOrtSetSessionGraphOptimizationLevel OrtSetSessionGraphOptimizationLevel;
- public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options, string configKey, string configValue);
+ ///
+ /// Add session config entry
+ ///
+ /// Native SessionOptions instance
+ /// Config key
+ /// Config value
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options,
+ IntPtr /* const char* */configKey,
+ IntPtr /* const char* */ configValue);
public static DOrtAddSessionConfigEntry OrtAddSessionConfigEntry;
///**
@@ -564,13 +575,49 @@ namespace Microsoft.ML.OnnxRuntime
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path);
- public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ symbolic_dim, int dim_override);
+ ///
+ /// Free Dimension override (by denotation)
+ ///
+ /// Native SessionOptions instance
+ /// Dimension denotation
+ /// Dimension value
+ public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options,
+ IntPtr /*(const char*)*/ dimDenotation,
+ long dimValue);
public static DOrtAddFreeDimensionOverride OrtAddFreeDimensionOverride;
- public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ library_path, out IntPtr /* (void**) */ library_handle);
+ ///
+ /// Free Dimension override (by name)
+ ///
+ /// Native SessionOptions instance
+ /// Dimension name
+ /// Dimension value
+ public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options,
+ IntPtr /*(const char*)*/ dimName,
+ long dimValue);
+ public static DOrtAddFreeDimensionOverrideByName OrtAddFreeDimensionOverrideByName;
+
+
+ ///
+ /// Register custom op library
+ ///
+ /// Native SessionOptions instance
+ /// Library path
+ /// (out) Native library handle
+ public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options,
+ IntPtr /*(const char*)*/ libraryPath,
+ out IntPtr /*(void**)*/ libraryHandle);
public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary;
- public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ name, IntPtr /* OrtValue* */ ort_value);
+ ///
+ /// Add initializer that is shared across Sessions using this SessionOptions (by denotation)
+ ///
+ /// Native SessionOptions instance
+ /// Name of the initializer
+ /// Native OrtValue instnce
+ public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options,
+ IntPtr /*(const char*)*/ name,
+ IntPtr /*(OrtValue*)*/ ortValue);
public static DOrtAddInitializer OrtAddInitializer;
#endregion
@@ -1048,6 +1095,27 @@ namespace Microsoft.ML.OnnxRuntime
#endregion
+ #region Misc API
+
+ ///
+ /// Queries all the execution providers supported in the native onnxruntime shared library
+ ///
+ /// (output) all execution providers (strings) supported in the native onnxruntime shared library
+ /// (output) number of execution providers (strings)
+
+ public delegate IntPtr /* (OrtStatus*) */ DOrtGetAvailableProviders(out IntPtr /* (char***) */ providers, out int /* (int*) */ numProviders);
+ public static DOrtGetAvailableProviders OrtGetAvailableProviders;
+
+ ///
+ /// Releases all execution provider strings allocated and returned by OrtGetAvailableProviders
+ ///
+ /// all execution providers (strings) returned by OrtGetAvailableProviders
+ /// number of execution providers (strings)
+
+ public delegate IntPtr /* (OrtStatus*) */ DOrtReleaseAvailableProviders(IntPtr /* (char**) */ providers, int /* (int) */ numProviders);
+ public static DOrtReleaseAvailableProviders OrtReleaseAvailableProviders;
+ #endregion
+
public static byte[] GetPlatformSerializedString(string str)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs
index 209484d8ab..f2c58a7776 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs
@@ -101,6 +101,36 @@ namespace Microsoft.ML.OnnxRuntime
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableTelemetryEvents(Handle));
}
+ ///
+ /// Queries all the execution providers supported in the native onnxruntime shared library
+ ///
+ public string[] GetAvailableProviders()
+ {
+ IntPtr availableProvidersHandle = IntPtr.Zero;
+ int numProviders;
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetAvailableProviders(out availableProvidersHandle, out numProviders));
+
+ var availableProviders = new string[numProviders];
+
+ try
+ {
+ for(int i=0; i
/// Tensor object
- /// For all tensor types but string tensors we endevour to use managed memory
+ /// For all tensor types but string tensors we endeavor to use managed memory
/// to avoid additional allocation and copy. This out parameter represents a chunk of pinned memory
///
/// discovered tensor element type
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index f2aead52cc..3b9cec2580 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -170,10 +170,41 @@ namespace Microsoft.ML.OnnxRuntime
#endregion //ExecutionProviderAppends
#region Public Methods
+
+ ///
+ /// (Deprecated) Loads a DLL named 'libraryPath' and looks for this entry point:
+ /// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
+ /// It then passes in the provided session options to this function along with the api base.
+ /// Deprecated in favor of RegisterCustomOpLibraryV2() because it provides users with the library handle
+ /// to release when all sessions relying on it are destroyed
+ ///
+ [ObsoleteAttribute("RegisterCustomOpLibrary(...) is obsolete. Use RegisterCustomOpLibraryV2(...) instead.", false)]
public void RegisterCustomOpLibrary(string libraryPath)
{
IntPtr libraryHandle = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, libraryPath, out libraryHandle));
+ var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned);
+ using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned))
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle));
+ }
+ }
+
+ ///
+ /// Loads a DLL named 'libraryPath' and looks for this entry point:
+ /// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
+ /// It then passes in the provided session options to this function along with the api base.
+ /// The handle to the loaded library is returned in 'libraryHandle'.
+ /// It can be unloaded by the caller after all sessions using the passed in
+ /// session options are destroyed, or if an error occurs and it is non null.
+ /// Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle
+ ///
+ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHandle)
+ {
+ var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned);
+ using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned))
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle));
+ }
}
///
@@ -186,18 +217,60 @@ namespace Microsoft.ML.OnnxRuntime
/// managed by the user (created using the CreateTensorWithDataAsOrtValue API) and it must outlive the session object
/// to which it is added.
///
- public void AddInitializer(string name, OrtValue ort_value)
+ public void AddInitializer(string name, OrtValue ortValue)
{
- NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, name, ort_value.Handle));
+ var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
+ using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, pinnedName.Pointer, ortValue.Handle));
+ }
}
+ ///
+ /// Set a single session configuration entry as a pair of strings
+ /// If a configuration with same key exists, this will overwrite the configuration with the given configValue
+ ///
public void AddSessionConfigEntry(string configKey, string configValue)
{
- NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle, configKey, configValue));
- }
- #endregion
+ var utf8NameConfigKeyPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configKey), GCHandleType.Pinned);
+ var utf8NameConfigValuePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configValue), GCHandleType.Pinned);
- internal IntPtr Handle
+ using (var pinnedConfigKeyName = new PinnedGCHandle(utf8NameConfigKeyPinned))
+ using (var pinnedConfigValueName = new PinnedGCHandle(utf8NameConfigValuePinned))
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle,
+ pinnedConfigKeyName.Pointer, pinnedConfigValueName.Pointer));
+ }
+ }
+
+ ///
+ /// Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable
+ /// optimizations that can take advantage of fixed values (such as memory planning, etc)
+ ///
+ public void AddFreeDimensionOverride(string dimDenotation, long dimValue)
+ {
+ var utf8DimDenotationPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimDenotation), GCHandleType.Pinned);
+ using (var pinnedDimDenotation = new PinnedGCHandle(utf8DimDenotationPinned))
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverride(handle, pinnedDimDenotation.Pointer, dimValue));
+ }
+ }
+
+ ///
+ /// Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable
+ /// optimizations that can take advantage of fixed values (such as memory planning, etc)
+ ///
+ public void AddFreeDimensionOverrideByName(string dimName, long dimValue)
+ {
+ var utf8DimNamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName), GCHandleType.Pinned);
+ using (var pinnedDimName = new PinnedGCHandle(utf8DimNamePinned))
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, pinnedDimName.Pointer, dimValue));
+ }
+ }
+ #endregion
+
+ internal IntPtr Handle
{
get
{
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index c41e650f37..1dcf4a092e 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -176,6 +176,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
ortEnvInstance.EnableTelemetryEvents();
}
+ [Fact]
+ public void GetAvailableProviders()
+ {
+ var ortEnvInstance = OrtEnv.Instance();
+ string[] providers = ortEnvInstance.GetAvailableProviders();
+
+ Assert.True(providers.Length > 0);
+ Assert.Equal("CPUExecutionProvider", providers[0]);
+
+# if USE_CUDA
+ Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"););
+#endif
+ }
+
[Fact]
public void CanCreateAndDisposeSessionWithModelPath()
{
@@ -409,6 +423,48 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.True(startTime1 <= startTime2 && startTime2 <= startTime3);
}
+ [Fact]
+ public void SessionOptionsFreeDimensionOverrides()
+ {
+
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "abs_free_dimensions.onnx");
+
+ // By Name
+ using (SessionOptions options = new SessionOptions())
+ {
+ options.AddFreeDimensionOverrideByName("Dim1", 4);
+ options.AddFreeDimensionOverrideByName("Dim2", 6);
+
+ using (var session = new InferenceSession(modelPath, options))
+ {
+ var inputMetadata = session.InputMetadata;
+ var dims = inputMetadata["x"].Dimensions;
+ Assert.Equal(3, dims.Length);
+ Assert.Equal(4, dims[0]);
+ Assert.Equal(6, dims[1]);
+ Assert.Equal(5, dims[2]);
+ }
+ }
+
+ // By Denotation
+ using (SessionOptions options = new SessionOptions())
+ {
+ options.AddFreeDimensionOverride("DATA_BATCH", 3);
+ options.AddFreeDimensionOverride("DATA_CHANNEL", 5);
+
+ using (var session = new InferenceSession(modelPath, options))
+ {
+ var inputMetadata = session.InputMetadata;
+ var dims = inputMetadata["x"].Dimensions;
+ Assert.Equal(3, dims.Length);
+ Assert.Equal(3, dims[0]);
+ Assert.Equal(5, dims[1]);
+ Assert.Equal(5, dims[2]);
+ }
+ }
+
+ }
+
private void validateRunResults(IReadOnlyCollection results)
{
// validate the results
@@ -903,6 +959,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
+ // Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle
+ private void UnloadLibrary(IntPtr libraryHandle)
+ {
+ if (libraryHandle != IntPtr.Zero)
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ if(!FreeLibrary(libraryHandle))
+ {
+ throw new Exception("Could not unload the provided shared library using its handle");
+ }
+ }
+
+ else
+ {
+ // TODO: Deal with non-Windows platforms for the .NET Core use-case
+ }
+ }
+ }
+
[SkipNonPackageTests]
private void TestRegisterCustomOpLibrary()
{
@@ -926,9 +1002,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");
+ IntPtr libraryHandle = IntPtr.Zero;
try
{
- option.RegisterCustomOpLibrary(libFullPath);
+
+ option.RegisterCustomOpLibraryV2(libFullPath, out libraryHandle);
}
catch (Exception ex)
{
@@ -979,6 +1057,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.True(tensorOut.SequenceEqual(expectedOut));
}
}
+
+ // Safe to unload the custom op shared library now
+ UnloadLibrary(libraryHandle);
}
}
@@ -1962,6 +2043,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
[DllImport("kernel32", CharSet = CharSet.Ansi)]
static extern UIntPtr GetProcAddress(IntPtr hModule, string procName);
+ [DllImport("kernel32.dll", CharSet = CharSet.Ansi)]
+ private static extern bool FreeLibrary(IntPtr hModule);
+
[Fact]
private void VerifyNativeMethodsExist()
{
@@ -1996,12 +2080,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
,"OrtSessionOptionsAppendExecutionProvider_Nnapi"
#endif
};
-
- var hModule = LoadLibrary(module);
- foreach (var ep in entryPointNames)
+ IntPtr libraryHandle = IntPtr.Zero;
+ try
{
- var x = GetProcAddress(hModule, ep);
- Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}");
+ libraryHandle = LoadLibrary(module);
+ foreach (var ep in entryPointNames)
+ {
+ var x = GetProcAddress(libraryHandle, ep);
+ Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}");
+ }
+ }
+
+ finally
+ {
+ UnloadLibrary(libraryHandle);
}
}
diff --git a/csharp/testdata/abs_free_dimensions.onnx b/csharp/testdata/abs_free_dimensions.onnx
new file mode 100644
index 0000000000..4c3d5ab6c4
--- /dev/null
+++ b/csharp/testdata/abs_free_dimensions.onnx
@@ -0,0 +1,14 @@
+backend-test:s
+
+xy"Abstest_absZ9
+x4
+2.
+Dim1
+DATA_BATCH
+Dim2DATA_CHANNEL
+b
+y
+
+Dim1
+Dim2
+B
\ No newline at end of file
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 9aec6c0fe0..1cd852078d 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -13,6 +13,7 @@
#include
#include
#include
+#include
#include
#include "test_allocator.h"
#include "test_fixture.h"
@@ -1006,6 +1007,11 @@ TEST(CApiTest, get_available_providers_cpp) {
std::vector providers = Ort::GetAvailableProviders();
ASSERT_TRUE(providers.size() > 0);
ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider"));
+
+#ifdef USE_CUDA
+ // CUDA EP will exist in the list but its position may vary based on other EPs included in the build
+ ASSERT_TRUE(std::find(providers.begin(), providers.end(), std::string("CUDAExecutionProvider")) != providers.end());
+#endif
}
// This test uses the CreateAndRegisterAllocator API to register an allocator with the env,