mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
added the overridableinitializers api (#1977)
This commit is contained in:
parent
19873c70dc
commit
b322e072b9
4 changed files with 145 additions and 4 deletions
|
|
@ -18,7 +18,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public class InferenceSession : IDisposable
|
||||
{
|
||||
protected IntPtr _nativeHandle;
|
||||
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata;
|
||||
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata, _overridableInitializerMetadata;
|
||||
private SessionOptions _builtInSessionOptions = null;
|
||||
private RunOptions _builtInRunOptions = null;
|
||||
|
||||
|
|
@ -88,6 +88,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Metadata regarding the overridable initializers, keyed by node names
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, NodeMetadata> OverridableInitializerMetadata
|
||||
{
|
||||
get
|
||||
{
|
||||
return _overridableInitializerMetadata;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Runs the loaded model for the given inputs, and fetches all the outputs.
|
||||
|
|
@ -238,12 +248,13 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
// Initialize input/output metadata
|
||||
_inputMetadata = new Dictionary<string, NodeMetadata>();
|
||||
_outputMetadata = new Dictionary<string, NodeMetadata>();
|
||||
_overridableInitializerMetadata = new Dictionary<string, NodeMetadata>();
|
||||
|
||||
// get input count
|
||||
UIntPtr inputCount = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount));
|
||||
|
||||
// get all the output names
|
||||
// get all the input names and metadata
|
||||
for (ulong i = 0; i < (ulong)inputCount; i++)
|
||||
{
|
||||
var iname = GetInputName(i);
|
||||
|
|
@ -253,12 +264,22 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
UIntPtr outputCount = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount));
|
||||
|
||||
// get all the output names
|
||||
// get all the output names and metadata
|
||||
for (ulong i = 0; i < (ulong)outputCount; i++)
|
||||
{
|
||||
_outputMetadata[GetOutputName(i)] = GetOutputMetadata(i);
|
||||
}
|
||||
|
||||
// get overridable initializer count
|
||||
UIntPtr initilaizerCount = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount));
|
||||
|
||||
// get all the overridable initializer names and metadata
|
||||
for (ulong i = 0; i < (ulong)initilaizerCount; i++)
|
||||
{
|
||||
_overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i);
|
||||
}
|
||||
|
||||
}
|
||||
catch (OnnxRuntimeException e)
|
||||
{
|
||||
|
|
@ -326,6 +347,31 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
return str;
|
||||
}
|
||||
|
||||
private string GetOverridableInitializerName(ulong index)
|
||||
{
|
||||
IntPtr nameHandle = IntPtr.Zero;
|
||||
string str = null;
|
||||
|
||||
IntPtr status = NativeMethods.OrtSessionGetOverridableInitializerName(
|
||||
_nativeHandle,
|
||||
(UIntPtr)index,
|
||||
NativeMemoryAllocator.DefaultInstance.Handle,
|
||||
out nameHandle);
|
||||
try
|
||||
{
|
||||
|
||||
NativeApiStatus.VerifySuccess(status);
|
||||
str = Marshal.PtrToStringAnsi(nameHandle); //assumes charset = ANSI
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (nameHandle != IntPtr.Zero)
|
||||
{
|
||||
NativeMemoryAllocator.DefaultInstance.FreeMemory(nameHandle);
|
||||
}
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
private NodeMetadata GetInputMetadata(ulong index)
|
||||
{
|
||||
|
|
@ -361,6 +407,23 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
}
|
||||
|
||||
private NodeMetadata GetOverridableInitializerMetadata(ulong index)
|
||||
{
|
||||
IntPtr typeInfo = IntPtr.Zero;
|
||||
try
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
|
||||
return GetMetadataFromTypeInfo(typeInfo);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (typeInfo != IntPtr.Zero)
|
||||
{
|
||||
NativeMethods.OrtReleaseTypeInfo(typeInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
|
||||
{
|
||||
OnnxValueType valueType;
|
||||
|
|
|
|||
|
|
@ -141,10 +141,15 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun));
|
||||
OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount));
|
||||
OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount));
|
||||
OrtSessionGetOverridableInitializerCount = (DOrtSessionGetOverridableInitializerCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerCount, typeof(DOrtSessionGetOverridableInitializerCount));
|
||||
|
||||
OrtSessionGetInputName = (DOrtSessionGetInputName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputName, typeof(DOrtSessionGetInputName));
|
||||
OrtSessionGetOutputName = (DOrtSessionGetOutputName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputName, typeof(DOrtSessionGetOutputName));
|
||||
OrtSessionGetOverridableInitializerName = (DOrtSessionGetOverridableInitializerName)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerName, typeof(DOrtSessionGetOverridableInitializerName));
|
||||
OrtSessionGetInputTypeInfo = (DOrtSessionGetInputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputTypeInfo, typeof(DOrtSessionGetInputTypeInfo));
|
||||
OrtSessionGetOutputTypeInfo = (DOrtSessionGetOutputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputTypeInfo, typeof(DOrtSessionGetOutputTypeInfo));
|
||||
OrtSessionGetOverridableInitializerTypeInfo = (DOrtSessionGetOverridableInitializerTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerTypeInfo, typeof(DOrtSessionGetOverridableInitializerTypeInfo));
|
||||
|
||||
OrtReleaseTypeInfo = (DOrtReleaseTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTypeInfo, typeof(DOrtReleaseTypeInfo));
|
||||
OrtReleaseSession = (DOrtReleaseSession)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSession, typeof(DOrtReleaseSession));
|
||||
|
||||
|
|
@ -273,6 +278,11 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
out UIntPtr count);
|
||||
public static DOrtSessionGetOutputCount OrtSessionGetOutputCount;
|
||||
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerCount(
|
||||
IntPtr /*(OrtSession*)*/ session,
|
||||
out UIntPtr count);
|
||||
public static DOrtSessionGetOverridableInitializerCount OrtSessionGetOverridableInitializerCount;
|
||||
|
||||
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputName(
|
||||
IntPtr /*(OrtSession*)*/ session,
|
||||
UIntPtr index,
|
||||
|
|
@ -287,6 +297,13 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
out IntPtr /*(char**)*/name);
|
||||
public static DOrtSessionGetOutputName OrtSessionGetOutputName;
|
||||
|
||||
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName(
|
||||
IntPtr /*(OrtSession*)*/ session,
|
||||
UIntPtr index,
|
||||
IntPtr /*(OrtAllocator*)*/ allocator,
|
||||
out IntPtr /*(char**)*/name);
|
||||
public static DOrtSessionGetOverridableInitializerName OrtSessionGetOverridableInitializerName;
|
||||
|
||||
// release the typeinfo using OrtReleaseTypeInfo
|
||||
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputTypeInfo(
|
||||
IntPtr /*(const OrtSession*)*/ session,
|
||||
|
|
@ -301,6 +318,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo);
|
||||
public static DOrtSessionGetOutputTypeInfo OrtSessionGetOutputTypeInfo;
|
||||
|
||||
// release the typeinfo using OrtReleaseTypeInfo
|
||||
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerTypeInfo(
|
||||
IntPtr /*(const OrtSession*)*/ session,
|
||||
UIntPtr index,
|
||||
out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo);
|
||||
public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo;
|
||||
|
||||
|
||||
public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session);
|
||||
public static DOrtReleaseTypeInfo OrtReleaseTypeInfo;
|
||||
|
||||
|
|
|
|||
|
|
@ -397,6 +397,55 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
} //opset
|
||||
}
|
||||
|
||||
[Fact]
|
||||
private void TestOverridableInitializerMetadata()
|
||||
{
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "overridable_initializer.onnx");
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
Assert.Equal(2, session.InputMetadata.Count);
|
||||
Assert.True(session.InputMetadata.ContainsKey("Label"));
|
||||
Assert.True(session.InputMetadata.ContainsKey("F2"));
|
||||
|
||||
Assert.Equal(1, session.OverridableInitializerMetadata.Count);
|
||||
Assert.True(session.OverridableInitializerMetadata.ContainsKey("F1"));
|
||||
Assert.True(session.OverridableInitializerMetadata["F1"].IsTensor);
|
||||
Assert.Equal(typeof(float), session.OverridableInitializerMetadata["F1"].ElementType);
|
||||
Assert.Equal(2, session.OverridableInitializerMetadata["F1"].Dimensions.Length);
|
||||
Assert.Equal(1, session.OverridableInitializerMetadata["F1"].Dimensions[0]);
|
||||
Assert.Equal(1, session.OverridableInitializerMetadata["F1"].Dimensions[1]);
|
||||
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var Label_input = new DenseTensor<bool>(new bool[] { true }, new int[] { 1, 1 });
|
||||
container.Add(NamedOnnxValue.CreateFromTensor("Label", Label_input));
|
||||
|
||||
var F2_input = new DenseTensor<string>(new string[] { "f2_string" }, new int[] { 1, 1 });
|
||||
container.Add(NamedOnnxValue.CreateFromTensor("F2", F2_input));
|
||||
|
||||
var F1_initializer = new DenseTensor<float>(new float[] { 2.0f }, new int[] { 1, 1 });
|
||||
container.Add(NamedOnnxValue.CreateFromTensor("F1", F1_initializer));
|
||||
|
||||
using (var result = session.Run(container))
|
||||
{
|
||||
var resultMap = new Dictionary<string, NamedOnnxValue>();
|
||||
|
||||
foreach (var output in result)
|
||||
{
|
||||
resultMap[output.Name] = output;
|
||||
}
|
||||
|
||||
Assert.True(resultMap.ContainsKey("Label0"));
|
||||
Assert.True(resultMap.ContainsKey("F20"));
|
||||
Assert.True(resultMap.ContainsKey("F11"));
|
||||
|
||||
var overriddenInitializer = resultMap["F11"].AsTensor<float>();
|
||||
Assert.NotNull(overriddenInitializer);
|
||||
Assert.True(overriddenInitializer.SequenceEqual(F1_initializer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
private void TestModelInputFloat()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -71,7 +71,11 @@
|
|||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
|
||||
<None Include="$(OnnxRuntimeCSharpRoot)\..\onnxruntime\test\testdata\overridable_initializer.onnx">
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
|
||||
<BuildEnvVars Include="OnnxRuntimeBuildDirectory=$(OnnxRuntimeBuildDirectory)" />
|
||||
</ItemGroup>
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue