Support more C# APIs (#5608)

This commit is contained in:
Hariharan Seshadri 2020-10-30 19:19:50 -07:00 committed by GitHub
parent 17bce6f07e
commit 7a80a4b526
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 301 additions and 18 deletions

View file

@ -275,6 +275,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtAllocatorFree = (DOrtAllocatorFree)Marshal.GetDelegateForFunctionPointer(api_.AllocatorFree, typeof(DOrtAllocatorFree));
OrtAllocatorGetInfo = (DOrtAllocatorGetInfo)Marshal.GetDelegateForFunctionPointer(api_.AllocatorGetInfo, typeof(DOrtAllocatorGetInfo));
OrtAddFreeDimensionOverride = (DOrtAddFreeDimensionOverride)Marshal.GetDelegateForFunctionPointer(api_.AddFreeDimensionOverride, typeof(DOrtAddFreeDimensionOverride));
OrtAddFreeDimensionOverrideByName = (DOrtAddFreeDimensionOverrideByName)Marshal.GetDelegateForFunctionPointer(api_.AddFreeDimensionOverrideByName, typeof(DOrtAddFreeDimensionOverrideByName));
OrtCreateIoBinding = (DOrtCreateIoBinding)Marshal.GetDelegateForFunctionPointer(api_.CreateIoBinding, typeof(DOrtCreateIoBinding));
OrtReleaseIoBinding = (DOrtReleaseIoBinding)Marshal.GetDelegateForFunctionPointer(api_.ReleaseIoBinding, typeof(DOrtReleaseIoBinding));
@ -321,6 +322,8 @@ namespace Microsoft.ML.OnnxRuntime
OrtModelMetadataLookupCustomMetadataMap = (DOrtModelMetadataLookupCustomMetadataMap)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataLookupCustomMetadataMap, typeof(DOrtModelMetadataLookupCustomMetadataMap));
OrtReleaseModelMetadata = (DOrtReleaseModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.ReleaseModelMetadata, typeof(DOrtReleaseModelMetadata));
OrtGetAvailableProviders = (DOrtGetAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.GetAvailableProviders, typeof(DOrtGetAvailableProviders));
OrtReleaseAvailableProviders = (DOrtReleaseAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAvailableProviders, typeof(DOrtReleaseAvailableProviders));
}
[DllImport(nativeLib, CharSet = charSet)]
@ -522,7 +525,15 @@ namespace Microsoft.ML.OnnxRuntime
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, GraphOptimizationLevel graphOptimizationLevel);
public static DOrtSetSessionGraphOptimizationLevel OrtSetSessionGraphOptimizationLevel;
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options, string configKey, string configValue);
/// <summary>
/// Add session config entry
/// </summary>
/// <param name="options">Native SessionOptions instance</param>
/// <param name="configKey">Config key</param>
/// <param name="configValue">Config value</param>
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options,
IntPtr /* const char* */configKey,
IntPtr /* const char* */ configValue);
public static DOrtAddSessionConfigEntry OrtAddSessionConfigEntry;
///**
@ -564,13 +575,49 @@ namespace Microsoft.ML.OnnxRuntime
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path);
public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ symbolic_dim, int dim_override);
/// <summary>
/// Free Dimension override (by denotation)
/// </summary>
/// <param name="options">Native SessionOptions instance</param>
/// <param name="dimDenotation">Dimension denotation</param>
/// <param name="dimValue">Dimension value</param>
public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ dimDenotation,
long dimValue);
public static DOrtAddFreeDimensionOverride OrtAddFreeDimensionOverride;
public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ library_path, out IntPtr /* (void**) */ library_handle);
/// <summary>
/// Free Dimension override (by name)
/// </summary>
/// <param name="options">Native SessionOptions instance</param>
/// <param name="dimName">Dimension name</param>
/// <param name="dimValue">Dimension value</param>
public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ dimName,
long dimValue);
public static DOrtAddFreeDimensionOverrideByName OrtAddFreeDimensionOverrideByName;
/// <summary>
/// Register custom op library
/// </summary>
/// <param name="options">Native SessionOptions instance</param>
/// <param name="libraryPath">Library path</param>
/// <param name="libraryHandle">(out) Native library handle</param>
public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options,
IntPtr /*(const char*)*/ libraryPath,
out IntPtr /*(void**)*/ libraryHandle);
public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary;
public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ name, IntPtr /* OrtValue* */ ort_value);
/// <summary>
/// Add initializer that is shared across Sessions using this SessionOptions (by denotation)
/// </summary>
/// <param name="options">Native SessionOptions instance</param>
/// <param name="name">Name of the initializer</param>
/// <param name="ortValue">Native OrtValue instnce</param>
public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ name,
IntPtr /*(OrtValue*)*/ ortValue);
public static DOrtAddInitializer OrtAddInitializer;
#endregion
@ -1048,6 +1095,27 @@ namespace Microsoft.ML.OnnxRuntime
#endregion
#region Misc API
/// <summary>
/// Queries all the execution providers supported in the native onnxruntime shared library
/// </summary>
/// <param name="providers">(output) all execution providers (strings) supported in the native onnxruntime shared library</param>
/// <param name="numProviders">(output) number of execution providers (strings)</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtGetAvailableProviders(out IntPtr /* (char***) */ providers, out int /* (int*) */ numProviders);
public static DOrtGetAvailableProviders OrtGetAvailableProviders;
/// <summary>
/// Releases all execution provider strings allocated and returned by OrtGetAvailableProviders
/// </summary>
/// <param name="providers">all execution providers (strings) returned by OrtGetAvailableProviders</param>
/// <param name="numProviders">number of execution providers (strings)</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtReleaseAvailableProviders(IntPtr /* (char**) */ providers, int /* (int) */ numProviders);
public static DOrtReleaseAvailableProviders OrtReleaseAvailableProviders;
#endregion
public static byte[] GetPlatformSerializedString(string str)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))

View file

@ -101,6 +101,36 @@ namespace Microsoft.ML.OnnxRuntime
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableTelemetryEvents(Handle));
}
/// <summary>
/// Queries all the execution providers supported in the native onnxruntime shared library
/// </summary>
public string[] GetAvailableProviders()
{
IntPtr availableProvidersHandle = IntPtr.Zero;
int numProviders;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetAvailableProviders(out availableProvidersHandle, out numProviders));
var availableProviders = new string[numProviders];
try
{
for(int i=0; i<numProviders; ++i)
{
availableProviders[i] = NativeOnnxValueHelper.StringFromNativeUtf8(Marshal.ReadIntPtr(availableProvidersHandle, IntPtr.Size * i));
}
}
finally
{
// Looks a bit weird that we might throw in finally(...)
// But the native method OrtReleaseAvailableProviders actually doesn't return a failure status
// If it does, it is BUG and we would like to propagate that to the user in the form of an exception
NativeApiStatus.VerifySuccess(NativeMethods.OrtReleaseAvailableProviders(availableProvidersHandle, numProviders));
}
return availableProviders;
}
#endregion
#region SafeHandle

View file

@ -120,7 +120,7 @@ namespace Microsoft.ML.OnnxRuntime
/// However, it re-uses managed memory if possible.
/// </summary>
/// <param name="value">Tensor object</param>
/// <param name="memoryHandle">For all tensor types but string tensors we endevour to use managed memory
/// <param name="memoryHandle">For all tensor types but string tensors we endeavor to use managed memory
/// to avoid additional allocation and copy. This out parameter represents a chunk of pinned memory
/// </param>
/// <param name="elementType">discovered tensor element type</param>

View file

@ -170,10 +170,41 @@ namespace Microsoft.ML.OnnxRuntime
#endregion //ExecutionProviderAppends
#region Public Methods
/// <summary>
/// (Deprecated) Loads a DLL named 'libraryPath' and looks for this entry point:
/// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
/// It then passes in the provided session options to this function along with the api base.
/// Deprecated in favor of RegisterCustomOpLibraryV2() because it provides users with the library handle
/// to release when all sessions relying on it are destroyed
/// </summary>
[ObsoleteAttribute("RegisterCustomOpLibrary(...) is obsolete. Use RegisterCustomOpLibraryV2(...) instead.", false)]
public void RegisterCustomOpLibrary(string libraryPath)
{
IntPtr libraryHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, libraryPath, out libraryHandle));
var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned);
using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle));
}
}
/// <summary>
/// Loads a DLL named 'libraryPath' and looks for this entry point:
/// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
/// It then passes in the provided session options to this function along with the api base.
/// The handle to the loaded library is returned in 'libraryHandle'.
/// It can be unloaded by the caller after all sessions using the passed in
/// session options are destroyed, or if an error occurs and it is non null.
/// Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle
/// </summary>
public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHandle)
{
var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned);
using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle));
}
}
/// <summary>
@ -186,18 +217,60 @@ namespace Microsoft.ML.OnnxRuntime
/// managed by the user (created using the CreateTensorWithDataAsOrtValue API) and it must outlive the session object
/// to which it is added.
/// </summary>
public void AddInitializer(string name, OrtValue ort_value)
public void AddInitializer(string name, OrtValue ortValue)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, name, ort_value.Handle));
var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, pinnedName.Pointer, ortValue.Handle));
}
}
/// <summary>
/// Set a single session configuration entry as a pair of strings
/// If a configuration with same key exists, this will overwrite the configuration with the given configValue
/// </summary>
public void AddSessionConfigEntry(string configKey, string configValue)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle, configKey, configValue));
}
#endregion
var utf8NameConfigKeyPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configKey), GCHandleType.Pinned);
var utf8NameConfigValuePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configValue), GCHandleType.Pinned);
internal IntPtr Handle
using (var pinnedConfigKeyName = new PinnedGCHandle(utf8NameConfigKeyPinned))
using (var pinnedConfigValueName = new PinnedGCHandle(utf8NameConfigValuePinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle,
pinnedConfigKeyName.Pointer, pinnedConfigValueName.Pointer));
}
}
/// <summary>
/// Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable
/// optimizations that can take advantage of fixed values (such as memory planning, etc)
/// </summary>
public void AddFreeDimensionOverride(string dimDenotation, long dimValue)
{
var utf8DimDenotationPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimDenotation), GCHandleType.Pinned);
using (var pinnedDimDenotation = new PinnedGCHandle(utf8DimDenotationPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverride(handle, pinnedDimDenotation.Pointer, dimValue));
}
}
/// <summary>
/// Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable
/// optimizations that can take advantage of fixed values (such as memory planning, etc)
/// </summary>
public void AddFreeDimensionOverrideByName(string dimName, long dimValue)
{
var utf8DimNamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName), GCHandleType.Pinned);
using (var pinnedDimName = new PinnedGCHandle(utf8DimNamePinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, pinnedDimName.Pointer, dimValue));
}
}
#endregion
internal IntPtr Handle
{
get
{

View file

@ -176,6 +176,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
ortEnvInstance.EnableTelemetryEvents();
}
[Fact]
public void GetAvailableProviders()
{
var ortEnvInstance = OrtEnv.Instance();
string[] providers = ortEnvInstance.GetAvailableProviders();
Assert.True(providers.Length > 0);
Assert.Equal("CPUExecutionProvider", providers[0]);
# if USE_CUDA
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"););
#endif
}
[Fact]
public void CanCreateAndDisposeSessionWithModelPath()
{
@ -409,6 +423,48 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.True(startTime1 <= startTime2 && startTime2 <= startTime3);
}
[Fact]
public void SessionOptionsFreeDimensionOverrides()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "abs_free_dimensions.onnx");
// By Name
using (SessionOptions options = new SessionOptions())
{
options.AddFreeDimensionOverrideByName("Dim1", 4);
options.AddFreeDimensionOverrideByName("Dim2", 6);
using (var session = new InferenceSession(modelPath, options))
{
var inputMetadata = session.InputMetadata;
var dims = inputMetadata["x"].Dimensions;
Assert.Equal(3, dims.Length);
Assert.Equal(4, dims[0]);
Assert.Equal(6, dims[1]);
Assert.Equal(5, dims[2]);
}
}
// By Denotation
using (SessionOptions options = new SessionOptions())
{
options.AddFreeDimensionOverride("DATA_BATCH", 3);
options.AddFreeDimensionOverride("DATA_CHANNEL", 5);
using (var session = new InferenceSession(modelPath, options))
{
var inputMetadata = session.InputMetadata;
var dims = inputMetadata["x"].Dimensions;
Assert.Equal(3, dims.Length);
Assert.Equal(3, dims[0]);
Assert.Equal(5, dims[1]);
Assert.Equal(5, dims[2]);
}
}
}
private void validateRunResults(IReadOnlyCollection<NamedOnnxValue> results)
{
// validate the results
@ -903,6 +959,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
// Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle
private void UnloadLibrary(IntPtr libraryHandle)
{
if (libraryHandle != IntPtr.Zero)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
if(!FreeLibrary(libraryHandle))
{
throw new Exception("Could not unload the provided shared library using its handle");
}
}
else
{
// TODO: Deal with non-Windows platforms for the .NET Core use-case
}
}
}
[SkipNonPackageTests]
private void TestRegisterCustomOpLibrary()
{
@ -926,9 +1002,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");
IntPtr libraryHandle = IntPtr.Zero;
try
{
option.RegisterCustomOpLibrary(libFullPath);
option.RegisterCustomOpLibraryV2(libFullPath, out libraryHandle);
}
catch (Exception ex)
{
@ -979,6 +1057,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Assert.True(tensorOut.SequenceEqual(expectedOut));
}
}
// Safe to unload the custom op shared library now
UnloadLibrary(libraryHandle);
}
}
@ -1962,6 +2043,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
[DllImport("kernel32", CharSet = CharSet.Ansi)]
static extern UIntPtr GetProcAddress(IntPtr hModule, string procName);
[DllImport("kernel32.dll", CharSet = CharSet.Ansi)]
private static extern bool FreeLibrary(IntPtr hModule);
[Fact]
private void VerifyNativeMethodsExist()
{
@ -1996,12 +2080,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
,"OrtSessionOptionsAppendExecutionProvider_Nnapi"
#endif
};
var hModule = LoadLibrary(module);
foreach (var ep in entryPointNames)
IntPtr libraryHandle = IntPtr.Zero;
try
{
var x = GetProcAddress(hModule, ep);
Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}");
libraryHandle = LoadLibrary(module);
foreach (var ep in entryPointNames)
{
var x = GetProcAddress(libraryHandle, ep);
Assert.False(x == UIntPtr.Zero, $"Entrypoint {ep} not found in module {module}");
}
}
finally
{
UnloadLibrary(libraryHandle);
}
}

View file

@ -0,0 +1,14 @@
 backend-test:s
xy"Abstest_absZ9
x4
2.
Dim1
DATA_BATCH
Dim2 DATA_CHANNEL
b
y

Dim1
Dim2
B

View file

@ -13,6 +13,7 @@
#include <sstream>
#include <atomic>
#include <mutex>
#include <algorithm>
#include <gtest/gtest.h>
#include "test_allocator.h"
#include "test_fixture.h"
@ -1006,6 +1007,11 @@ TEST(CApiTest, get_available_providers_cpp) {
std::vector<std::string> providers = Ort::GetAvailableProviders();
ASSERT_TRUE(providers.size() > 0);
ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider"));
#ifdef USE_CUDA
// CUDA EP will exist in the list but its position may vary based on other EPs included in the build
ASSERT_TRUE(std::find(providers.begin(), providers.end(), std::string("CUDAExecutionProvider")) != providers.end());
#endif
}
// This test uses the CreateAndRegisterAllocator API to register an allocator with the env,