From b322e072b982d2df5f0690da4e3b5319ba9d6d87 Mon Sep 17 00:00:00 2001 From: shahasad <43590019+shahasad@users.noreply.github.com> Date: Fri, 4 Oct 2019 16:38:00 -0700 Subject: [PATCH] added the overridableinitializers api (#1977) --- .../InferenceSession.cs | 69 ++++++++++++++++++- .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 25 +++++++ .../InferenceTest.cs | 49 +++++++++++++ .../Microsoft.ML.OnnxRuntime.Tests.csproj | 6 +- 4 files changed, 145 insertions(+), 4 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 77b5b4c624..688ea4a1f8 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.OnnxRuntime public class InferenceSession : IDisposable { protected IntPtr _nativeHandle; - protected Dictionary _inputMetadata, _outputMetadata; + protected Dictionary _inputMetadata, _outputMetadata, _overridableInitializerMetadata; private SessionOptions _builtInSessionOptions = null; private RunOptions _builtInRunOptions = null; @@ -88,6 +88,16 @@ namespace Microsoft.ML.OnnxRuntime } } + /// + /// Metadata regarding the overridable initializers, keyed by node names + /// + public IReadOnlyDictionary OverridableInitializerMetadata + { + get + { + return _overridableInitializerMetadata; + } + } /// /// Runs the loaded model for the given inputs, and fetches all the outputs. @@ -238,12 +248,13 @@ namespace Microsoft.ML.OnnxRuntime // Initialize input/output metadata _inputMetadata = new Dictionary(); _outputMetadata = new Dictionary(); + _overridableInitializerMetadata = new Dictionary(); // get input count UIntPtr inputCount = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount)); - // get all the output names + // get all the input names and metadata for (ulong i = 0; i < (ulong)inputCount; i++) { var iname = GetInputName(i); @@ -253,12 +264,22 @@ namespace Microsoft.ML.OnnxRuntime UIntPtr outputCount = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount)); - // get all the output names + // get all the output names and metadata for (ulong i = 0; i < (ulong)outputCount; i++) { _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i); } + // get overridable initializer count + UIntPtr initilaizerCount = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount)); + + // get all the overridable initializer names and metadata + for (ulong i = 0; i < (ulong)initilaizerCount; i++) + { + _overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i); + } + } catch (OnnxRuntimeException e) { @@ -326,6 +347,31 @@ namespace Microsoft.ML.OnnxRuntime return str; } + private string GetOverridableInitializerName(ulong index) + { + IntPtr nameHandle = IntPtr.Zero; + string str = null; + + IntPtr status = NativeMethods.OrtSessionGetOverridableInitializerName( + _nativeHandle, + (UIntPtr)index, + NativeMemoryAllocator.DefaultInstance.Handle, + out nameHandle); + try + { + + NativeApiStatus.VerifySuccess(status); + str = Marshal.PtrToStringAnsi(nameHandle); //assumes charset = ANSI + } + finally + { + if (nameHandle != IntPtr.Zero) + { + NativeMemoryAllocator.DefaultInstance.FreeMemory(nameHandle); + } + } + return str; + } private NodeMetadata GetInputMetadata(ulong index) { @@ -361,6 +407,23 @@ namespace Microsoft.ML.OnnxRuntime } } + private NodeMetadata GetOverridableInitializerMetadata(ulong index) + { + IntPtr typeInfo = IntPtr.Zero; + try + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo)); + return GetMetadataFromTypeInfo(typeInfo); + } + finally + { + if (typeInfo != IntPtr.Zero) + { + NativeMethods.OrtReleaseTypeInfo(typeInfo); + } + } + } + internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) { OnnxValueType valueType; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 3d31f4fee4..4277a852d0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -141,10 +141,15 @@ namespace Microsoft.ML.OnnxRuntime OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun)); OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount)); OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount)); + OrtSessionGetOverridableInitializerCount = (DOrtSessionGetOverridableInitializerCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerCount, typeof(DOrtSessionGetOverridableInitializerCount)); + OrtSessionGetInputName = (DOrtSessionGetInputName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputName, typeof(DOrtSessionGetInputName)); OrtSessionGetOutputName = (DOrtSessionGetOutputName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputName, typeof(DOrtSessionGetOutputName)); + OrtSessionGetOverridableInitializerName = (DOrtSessionGetOverridableInitializerName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerName, typeof(DOrtSessionGetOverridableInitializerName)); OrtSessionGetInputTypeInfo = (DOrtSessionGetInputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputTypeInfo, typeof(DOrtSessionGetInputTypeInfo)); OrtSessionGetOutputTypeInfo = (DOrtSessionGetOutputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputTypeInfo, typeof(DOrtSessionGetOutputTypeInfo)); + OrtSessionGetOverridableInitializerTypeInfo = (DOrtSessionGetOverridableInitializerTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerTypeInfo, typeof(DOrtSessionGetOverridableInitializerTypeInfo)); + OrtReleaseTypeInfo = (DOrtReleaseTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTypeInfo, typeof(DOrtReleaseTypeInfo)); OrtReleaseSession = (DOrtReleaseSession)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSession, typeof(DOrtReleaseSession)); @@ -273,6 +278,11 @@ namespace Microsoft.ML.OnnxRuntime out UIntPtr count); public static DOrtSessionGetOutputCount OrtSessionGetOutputCount; + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerCount( + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); + public static DOrtSessionGetOverridableInitializerCount OrtSessionGetOverridableInitializerCount; + public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputName( IntPtr /*(OrtSession*)*/ session, UIntPtr index, @@ -287,6 +297,13 @@ namespace Microsoft.ML.OnnxRuntime out IntPtr /*(char**)*/name); public static DOrtSessionGetOutputName OrtSessionGetOutputName; + public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/name); + public static DOrtSessionGetOverridableInitializerName OrtSessionGetOverridableInitializerName; + // release the typeinfo using OrtReleaseTypeInfo public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputTypeInfo( IntPtr /*(const OrtSession*)*/ session, @@ -301,6 +318,14 @@ namespace Microsoft.ML.OnnxRuntime out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetOutputTypeInfo OrtSessionGetOutputTypeInfo; + // release the typeinfo using OrtReleaseTypeInfo + public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); + public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo; + + public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session); public static DOrtReleaseTypeInfo OrtReleaseTypeInfo; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 4257d5e2b1..6e2b5417ca 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -397,6 +397,55 @@ namespace Microsoft.ML.OnnxRuntime.Tests } //opset } + [Fact] + private void TestOverridableInitializerMetadata() + { + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "overridable_initializer.onnx"); + using (var session = new InferenceSession(modelPath)) + { + Assert.Equal(2, session.InputMetadata.Count); + Assert.True(session.InputMetadata.ContainsKey("Label")); + Assert.True(session.InputMetadata.ContainsKey("F2")); + + Assert.Equal(1, session.OverridableInitializerMetadata.Count); + Assert.True(session.OverridableInitializerMetadata.ContainsKey("F1")); + Assert.True(session.OverridableInitializerMetadata["F1"].IsTensor); + Assert.Equal(typeof(float), session.OverridableInitializerMetadata["F1"].ElementType); + Assert.Equal(2, session.OverridableInitializerMetadata["F1"].Dimensions.Length); + Assert.Equal(1, session.OverridableInitializerMetadata["F1"].Dimensions[0]); + Assert.Equal(1, session.OverridableInitializerMetadata["F1"].Dimensions[1]); + + var container = new List(); + var Label_input = new DenseTensor(new bool[] { true }, new int[] { 1, 1 }); + container.Add(NamedOnnxValue.CreateFromTensor("Label", Label_input)); + + var F2_input = new DenseTensor(new string[] { "f2_string" }, new int[] { 1, 1 }); + container.Add(NamedOnnxValue.CreateFromTensor("F2", F2_input)); + + var F1_initializer = new DenseTensor(new float[] { 2.0f }, new int[] { 1, 1 }); + container.Add(NamedOnnxValue.CreateFromTensor("F1", F1_initializer)); + + using (var result = session.Run(container)) + { + var resultMap = new Dictionary(); + + foreach (var output in result) + { + resultMap[output.Name] = output; + } + + Assert.True(resultMap.ContainsKey("Label0")); + Assert.True(resultMap.ContainsKey("F20")); + Assert.True(resultMap.ContainsKey("F11")); + + var overriddenInitializer = resultMap["F11"].AsTensor(); + Assert.NotNull(overriddenInitializer); + Assert.True(overriddenInitializer.SequenceEqual(F1_initializer)); + } + } + } + + [Fact] private void TestModelInputFloat() { diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj index 075a6962ce..9ecd184d61 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj @@ -71,7 +71,11 @@ Always false - + + Always + false + +