mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
InferenceSession ctor with byte array in C# (#1883)
* add ctor overloads that accept model byte array * doxygen. mark Init method as private. * doxygen * rename test method for clarity * PR feedback - add two overloads that accept either model path or model byte array * update native signature to align with latest codebase * fix native call
This commit is contained in:
parent
294db0f978
commit
034aa80167
3 changed files with 76 additions and 6 deletions
|
|
@ -46,6 +46,25 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
Init(modelPath, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructs an InferenceSession from a model data in byte array
|
||||
/// </summary>
|
||||
/// <param name="model"></param>
|
||||
public InferenceSession(byte[] model)
|
||||
{
|
||||
_builtInSessionOptions = new SessionOptions(); // need to be disposed
|
||||
Init(model, _builtInSessionOptions);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructs an InferenceSession from a model data in byte array, with some additional session options
|
||||
/// </summary>
|
||||
/// <param name="model"></param>
|
||||
/// <param name="options"></param>
|
||||
public InferenceSession(byte[] model, SessionOptions options)
|
||||
{
|
||||
Init(model, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Meta data regarding the input nodes, keyed by input names
|
||||
|
|
@ -182,17 +201,39 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
#region private methods
|
||||
|
||||
protected void Init(string modelPath, SessionOptions options)
|
||||
private void Init(string modelPath, SessionOptions options)
|
||||
{
|
||||
var envHandle = OnnxRuntime.Handle;
|
||||
var session = IntPtr.Zero;
|
||||
|
||||
_nativeHandle = IntPtr.Zero;
|
||||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out session));
|
||||
else
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out session));
|
||||
|
||||
InitWithSessionHandle(session, options);
|
||||
}
|
||||
|
||||
private void Init(byte[] modelData, SessionOptions options)
|
||||
{
|
||||
var envHandle = OnnxRuntime.Handle;
|
||||
var session = IntPtr.Zero;
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArray(envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, out session));
|
||||
|
||||
InitWithSessionHandle(session, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes the session object with a native session handle
|
||||
/// </summary>
|
||||
/// <param name="session">Handle of a native session object</param>
|
||||
/// <param name="options">Session options</param>
|
||||
private void InitWithSessionHandle(IntPtr session, SessionOptions options)
|
||||
{
|
||||
_nativeHandle = session;
|
||||
try
|
||||
{
|
||||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out _nativeHandle));
|
||||
else
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out _nativeHandle));
|
||||
|
||||
// Initialize input/output metadata
|
||||
_inputMetadata = new Dictionary<string, NodeMetadata>();
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtReleaseStatus = (DOrtReleaseStatus)Marshal.GetDelegateForFunctionPointer(api_.ReleaseStatus, typeof(DOrtReleaseStatus));
|
||||
|
||||
OrtCreateSession = (DOrtCreateSession)Marshal.GetDelegateForFunctionPointer(api_.CreateSession, typeof(DOrtCreateSession));
|
||||
OrtCreateSessionFromArray = (DOrtCreateSessionFromArray)Marshal.GetDelegateForFunctionPointer(api_.CreateSessionFromArray, typeof(DOrtCreateSessionFromArray));
|
||||
OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun));
|
||||
OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount));
|
||||
OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount));
|
||||
|
|
@ -242,6 +243,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
out IntPtr /**/ session);
|
||||
public static DOrtCreateSession OrtCreateSession;
|
||||
|
||||
public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArray(
|
||||
IntPtr /* (OrtEnv*) */ environment,
|
||||
byte[] modelData,
|
||||
UIntPtr modelSize,
|
||||
IntPtr /* (OrtSessionOptions*) */sessionOptions,
|
||||
out IntPtr /**/ session);
|
||||
public static DOrtCreateSessionFromArray OrtCreateSessionFromArray;
|
||||
|
||||
public delegate IntPtr /*(ONNStatus*)*/ DOrtRun(
|
||||
IntPtr /*(OrtSession*)*/ session,
|
||||
IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
|
||||
|
|
|
|||
|
|
@ -792,6 +792,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
private void TestInferenceSessionWithByteArray()
|
||||
{
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT.pb");
|
||||
byte[] modelData = File.ReadAllBytes(modelPath);
|
||||
|
||||
using (var session = new InferenceSession(modelData))
|
||||
{
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var tensorIn = new DenseTensor<float>(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 });
|
||||
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
|
||||
container.Add(nov);
|
||||
using (var res = session.Run(container))
|
||||
{
|
||||
var tensorOut = res.First().AsTensor<float>();
|
||||
Assert.True(tensorOut.SequenceEqual(tensorIn));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[DllImport("kernel32", SetLastError = true)]
|
||||
static extern IntPtr LoadLibrary(string lpFileName);
|
||||
|
|
|
|||
Loading…
Reference in a new issue