InferenceSession ctor with byte array in C# (#1883)

* add ctor overloads that accept model byte array

* doxygen. mark Init method as private.

* doxygen

* rename test method for clarity

* PR feedback - add two overloads that accept either model path or model byte array

* update native signature to align with latest codebase

* fix native call
This commit is contained in:
yeohan 2019-09-25 03:59:04 +09:00 committed by Pranav Sharma
parent 294db0f978
commit 034aa80167
3 changed files with 76 additions and 6 deletions

View file

@ -46,6 +46,25 @@ namespace Microsoft.ML.OnnxRuntime
Init(modelPath, options);
}
/// <summary>
/// Constructs an InferenceSession from model data held in a byte array,
/// using a newly created <see cref="SessionOptions"/> instance.
/// </summary>
/// <param name="model">Byte array containing the model data. Must not be null.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="model"/> is null.</exception>
public InferenceSession(byte[] model)
{
    // Validate before allocating the built-in options: a null argument should fail
    // with a descriptive exception and must not leave a SessionOptions instance
    // behind that nobody disposes.
    if (model == null)
    {
        throw new ArgumentNullException(nameof(model));
    }

    _builtInSessionOptions = new SessionOptions(); // need to be disposed
    Init(model, _builtInSessionOptions);
}
/// <summary>
/// Constructs an InferenceSession from model data held in a byte array,
/// using caller-supplied session options.
/// </summary>
/// <param name="model">Byte array containing the model data. Must not be null.</param>
/// <param name="options">Session options used when creating the session. Must not be null.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="model"/> or <paramref name="options"/> is null.
/// </exception>
public InferenceSession(byte[] model, SessionOptions options)
{
    // Init dereferences both arguments (model length and options handle); fail early
    // with the offending parameter name instead of a NullReferenceException.
    if (model == null)
    {
        throw new ArgumentNullException(nameof(model));
    }

    if (options == null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    Init(model, options);
}
/// <summary>
/// Meta data regarding the input nodes, keyed by input names
@ -182,17 +201,39 @@ namespace Microsoft.ML.OnnxRuntime
#region private methods
protected void Init(string modelPath, SessionOptions options)
/// <summary>
/// Creates a native session from a model file path and completes initialization
/// with the resulting session handle.
/// </summary>
/// <param name="modelPath">Path of the model file</param>
/// <param name="options">Session options whose native handle is passed to session creation</param>
private void Init(string modelPath, SessionOptions options)
{
    var environment = OnnxRuntime.Handle;
    var createdSession = IntPtr.Zero;
    _nativeHandle = IntPtr.Zero;

    // The path is handed to the native call as raw bytes: UTF-16 on Windows, UTF-8 elsewhere.
    var pathEncoding = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
        ? System.Text.Encoding.Unicode
        : System.Text.Encoding.UTF8;
    var encodedPath = pathEncoding.GetBytes(modelPath);

    var status = NativeMethods.OrtCreateSession(environment, encodedPath, options.Handle, out createdSession);
    NativeApiStatus.VerifySuccess(status);

    InitWithSessionHandle(createdSession, options);
}
/// <summary>
/// Creates a native session from model data held in memory and completes
/// initialization with the resulting session handle.
/// </summary>
/// <param name="modelData">Byte array containing the model data</param>
/// <param name="options">Session options whose native handle is passed to session creation</param>
private void Init(byte[] modelData, SessionOptions options)
{
    var environment = OnnxRuntime.Handle;
    var modelSize = (UIntPtr)modelData.Length;
    var createdSession = IntPtr.Zero;

    var status = NativeMethods.OrtCreateSessionFromArray(
        environment,
        modelData,
        modelSize,
        options.Handle,
        out createdSession);
    NativeApiStatus.VerifySuccess(status);

    InitWithSessionHandle(createdSession, options);
}
/// <summary>
/// Initializes the session object with a native session handle
/// </summary>
/// <param name="session">Handle of a native session object</param>
/// <param name="options">Session options</param>
private void InitWithSessionHandle(IntPtr session, SessionOptions options)
{
_nativeHandle = session;
try
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out _nativeHandle));
else
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out _nativeHandle));
// Initialize input/output metadata
_inputMetadata = new Dictionary<string, NodeMetadata>();

View file

@ -137,6 +137,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtReleaseStatus = (DOrtReleaseStatus)Marshal.GetDelegateForFunctionPointer(api_.ReleaseStatus, typeof(DOrtReleaseStatus));
OrtCreateSession = (DOrtCreateSession)Marshal.GetDelegateForFunctionPointer(api_.CreateSession, typeof(DOrtCreateSession));
OrtCreateSessionFromArray = (DOrtCreateSessionFromArray)Marshal.GetDelegateForFunctionPointer(api_.CreateSessionFromArray, typeof(DOrtCreateSessionFromArray));
OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun));
OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount));
OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount));
@ -242,6 +243,14 @@ namespace Microsoft.ML.OnnxRuntime
out IntPtr /**/ session);
public static DOrtCreateSession OrtCreateSession;
/// <summary>
/// Managed signature for the native CreateSessionFromArray entry point, which
/// creates a session from model data held in a byte array rather than a file path.
/// The returned IntPtr is a native status pointer (OrtStatus*) that callers pass
/// to NativeApiStatus.VerifySuccess.
/// </summary>
/// <param name="environment">Native environment handle (OrtEnv*)</param>
/// <param name="modelData">Byte array containing the model data</param>
/// <param name="modelSize">Length of <paramref name="modelData"/> in bytes</param>
/// <param name="sessionOptions">Native session options handle (OrtSessionOptions*)</param>
/// <param name="session">Receives the handle of the created native session</param>
public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArray(
IntPtr /* (OrtEnv*) */ environment,
byte[] modelData,
UIntPtr modelSize,
IntPtr /* (OrtSessionOptions*) */sessionOptions,
out IntPtr /**/ session);
// Delegate instance bound to the native function pointer during NativeMethods initialization.
public static DOrtCreateSessionFromArray OrtCreateSessionFromArray;
public delegate IntPtr /*(ONNStatus*)*/ DOrtRun(
IntPtr /*(OrtSession*)*/ session,
IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options

View file

@ -792,6 +792,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
[Fact]
private void TestInferenceSessionWithByteArray()
{
    // The model takes a 1x5 input of a fixed type and echoes it back, so the
    // output tensor is expected to equal the input tensor element for element.
    var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT.pb");
    var modelBytes = File.ReadAllBytes(modelFile);

    using (var session = new InferenceSession(modelBytes))
    {
        var expected = new DenseTensor<float>(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 });
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input", expected)
        };

        using (var results = session.Run(inputs))
        {
            var actual = results.First().AsTensor<float>();
            Assert.True(actual.SequenceEqual(expected));
        }
    }
}
// P/Invoke declaration for the LoadLibrary function exported by kernel32.
// SetLastError = true asks the runtime to capture the native last-error code
// so it can be read after the call. Takes the library file name and returns
// an IntPtr; callers are not visible in this part of the file.
[DllImport("kernel32", SetLastError = true)]
static extern IntPtr LoadLibrary(string lpFileName);