added the overridableinitializers api (#1977)

This commit is contained in:
shahasad 2019-10-04 16:38:00 -07:00 committed by GitHub
parent 19873c70dc
commit b322e072b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 4 deletions

View file

@ -18,7 +18,7 @@ namespace Microsoft.ML.OnnxRuntime
public class InferenceSession : IDisposable
{
protected IntPtr _nativeHandle;
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata;
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata, _overridableInitializerMetadata;
private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;
@ -88,6 +88,16 @@ namespace Microsoft.ML.OnnxRuntime
}
}
/// <summary>
/// Metadata regarding the overridable initializers, keyed by node names
/// </summary>
public IReadOnlyDictionary<string, NodeMetadata> OverridableInitializerMetadata
{
get
{
return _overridableInitializerMetadata;
}
}
/// <summary>
/// Runs the loaded model for the given inputs, and fetches all the outputs.
@ -238,12 +248,13 @@ namespace Microsoft.ML.OnnxRuntime
// Initialize input/output metadata
_inputMetadata = new Dictionary<string, NodeMetadata>();
_outputMetadata = new Dictionary<string, NodeMetadata>();
_overridableInitializerMetadata = new Dictionary<string, NodeMetadata>();
// get input count
UIntPtr inputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount));
// get all the output names
// get all the input names and metadata
for (ulong i = 0; i < (ulong)inputCount; i++)
{
var iname = GetInputName(i);
@ -253,12 +264,22 @@ namespace Microsoft.ML.OnnxRuntime
UIntPtr outputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount));
// get all the output names
// get all the output names and metadata
for (ulong i = 0; i < (ulong)outputCount; i++)
{
_outputMetadata[GetOutputName(i)] = GetOutputMetadata(i);
}
// get overridable initializer count
UIntPtr initilaizerCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount));
// get all the overridable initializer names and metadata
for (ulong i = 0; i < (ulong)initilaizerCount; i++)
{
_overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i);
}
}
catch (OnnxRuntimeException e)
{
@ -326,6 +347,31 @@ namespace Microsoft.ML.OnnxRuntime
return str;
}
private string GetOverridableInitializerName(ulong index)
{
IntPtr nameHandle = IntPtr.Zero;
string str = null;
IntPtr status = NativeMethods.OrtSessionGetOverridableInitializerName(
_nativeHandle,
(UIntPtr)index,
NativeMemoryAllocator.DefaultInstance.Handle,
out nameHandle);
try
{
NativeApiStatus.VerifySuccess(status);
str = Marshal.PtrToStringAnsi(nameHandle); //assumes charset = ANSI
}
finally
{
if (nameHandle != IntPtr.Zero)
{
NativeMemoryAllocator.DefaultInstance.FreeMemory(nameHandle);
}
}
return str;
}
private NodeMetadata GetInputMetadata(ulong index)
{
@ -361,6 +407,23 @@ namespace Microsoft.ML.OnnxRuntime
}
}
private NodeMetadata GetOverridableInitializerMetadata(ulong index)
{
IntPtr typeInfo = IntPtr.Zero;
try
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
return GetMetadataFromTypeInfo(typeInfo);
}
finally
{
if (typeInfo != IntPtr.Zero)
{
NativeMethods.OrtReleaseTypeInfo(typeInfo);
}
}
}
internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
{
OnnxValueType valueType;

View file

@ -141,10 +141,15 @@ namespace Microsoft.ML.OnnxRuntime
OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun));
OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount));
OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount));
OrtSessionGetOverridableInitializerCount = (DOrtSessionGetOverridableInitializerCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerCount, typeof(DOrtSessionGetOverridableInitializerCount));
OrtSessionGetInputName = (DOrtSessionGetInputName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputName, typeof(DOrtSessionGetInputName));
OrtSessionGetOutputName = (DOrtSessionGetOutputName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputName, typeof(DOrtSessionGetOutputName));
OrtSessionGetOverridableInitializerName = (DOrtSessionGetOverridableInitializerName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerName, typeof(DOrtSessionGetOverridableInitializerName));
OrtSessionGetInputTypeInfo = (DOrtSessionGetInputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputTypeInfo, typeof(DOrtSessionGetInputTypeInfo));
OrtSessionGetOutputTypeInfo = (DOrtSessionGetOutputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputTypeInfo, typeof(DOrtSessionGetOutputTypeInfo));
OrtSessionGetOverridableInitializerTypeInfo = (DOrtSessionGetOverridableInitializerTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerTypeInfo, typeof(DOrtSessionGetOverridableInitializerTypeInfo));
OrtReleaseTypeInfo = (DOrtReleaseTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTypeInfo, typeof(DOrtReleaseTypeInfo));
OrtReleaseSession = (DOrtReleaseSession)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSession, typeof(DOrtReleaseSession));
@ -273,6 +278,11 @@ namespace Microsoft.ML.OnnxRuntime
out UIntPtr count);
public static DOrtSessionGetOutputCount OrtSessionGetOutputCount;
public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerCount(
IntPtr /*(OrtSession*)*/ session,
out UIntPtr count);
public static DOrtSessionGetOverridableInitializerCount OrtSessionGetOverridableInitializerCount;
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputName(
IntPtr /*(OrtSession*)*/ session,
UIntPtr index,
@ -287,6 +297,13 @@ namespace Microsoft.ML.OnnxRuntime
out IntPtr /*(char**)*/name);
public static DOrtSessionGetOutputName OrtSessionGetOutputName;
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName(
IntPtr /*(OrtSession*)*/ session,
UIntPtr index,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/name);
public static DOrtSessionGetOverridableInitializerName OrtSessionGetOverridableInitializerName;
// release the typeinfo using OrtReleaseTypeInfo
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputTypeInfo(
IntPtr /*(const OrtSession*)*/ session,
@ -301,6 +318,14 @@ namespace Microsoft.ML.OnnxRuntime
out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo);
public static DOrtSessionGetOutputTypeInfo OrtSessionGetOutputTypeInfo;
// release the typeinfo using OrtReleaseTypeInfo
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerTypeInfo(
IntPtr /*(const OrtSession*)*/ session,
UIntPtr index,
out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo);
public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo;
public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session);
public static DOrtReleaseTypeInfo OrtReleaseTypeInfo;

View file

@ -397,6 +397,55 @@ namespace Microsoft.ML.OnnxRuntime.Tests
} //opset
}
[Fact]
private void TestOverridableInitializerMetadata()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "overridable_initializer.onnx");
using (var session = new InferenceSession(modelPath))
{
Assert.Equal(2, session.InputMetadata.Count);
Assert.True(session.InputMetadata.ContainsKey("Label"));
Assert.True(session.InputMetadata.ContainsKey("F2"));
Assert.Equal(1, session.OverridableInitializerMetadata.Count);
Assert.True(session.OverridableInitializerMetadata.ContainsKey("F1"));
Assert.True(session.OverridableInitializerMetadata["F1"].IsTensor);
Assert.Equal(typeof(float), session.OverridableInitializerMetadata["F1"].ElementType);
Assert.Equal(2, session.OverridableInitializerMetadata["F1"].Dimensions.Length);
Assert.Equal(1, session.OverridableInitializerMetadata["F1"].Dimensions[0]);
Assert.Equal(1, session.OverridableInitializerMetadata["F1"].Dimensions[1]);
var container = new List<NamedOnnxValue>();
var Label_input = new DenseTensor<bool>(new bool[] { true }, new int[] { 1, 1 });
container.Add(NamedOnnxValue.CreateFromTensor("Label", Label_input));
var F2_input = new DenseTensor<string>(new string[] { "f2_string" }, new int[] { 1, 1 });
container.Add(NamedOnnxValue.CreateFromTensor("F2", F2_input));
var F1_initializer = new DenseTensor<float>(new float[] { 2.0f }, new int[] { 1, 1 });
container.Add(NamedOnnxValue.CreateFromTensor("F1", F1_initializer));
using (var result = session.Run(container))
{
var resultMap = new Dictionary<string, NamedOnnxValue>();
foreach (var output in result)
{
resultMap[output.Name] = output;
}
Assert.True(resultMap.ContainsKey("Label0"));
Assert.True(resultMap.ContainsKey("F20"));
Assert.True(resultMap.ContainsKey("F11"));
var overriddenInitializer = resultMap["F11"].AsTensor<float>();
Assert.NotNull(overriddenInitializer);
Assert.True(overriddenInitializer.SequenceEqual(F1_initializer));
}
}
}
[Fact]
private void TestModelInputFloat()
{

View file

@ -71,7 +71,11 @@
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(OnnxRuntimeCSharpRoot)\..\onnxruntime\test\testdata\overridable_initializer.onnx">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<BuildEnvVars Include="OnnxRuntimeBuildDirectory=$(OnnxRuntimeBuildDirectory)" />
</ItemGroup>