From 034aa8016767db946887d2ae53b05ddee76592d3 Mon Sep 17 00:00:00 2001
From: yeohan <35736802+yeohan@users.noreply.github.com>
Date: Wed, 25 Sep 2019 03:59:04 +0900
Subject: [PATCH] InferenceSession ctor with byte array in C# (#1883)
* add ctor overloads that accept model byte array
* doxygen. mark Init method as private.
* doxygen
* rename test method for clarity
* PR feedback - add two overloads that accept either model path or model byte array
* update native signature to align with latest codebase
* fix native call
---
.../InferenceSession.cs | 53 ++++++++++++++++---
.../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 9 ++++
.../InferenceTest.cs | 20 +++++++
3 files changed, 76 insertions(+), 6 deletions(-)
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
index 7964302956..77b5b4c624 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
@@ -46,6 +46,25 @@ namespace Microsoft.ML.OnnxRuntime
Init(modelPath, options);
}
+ ///
+ /// Constructs an InferenceSession from a model data in byte array
+ ///
+ ///
+ public InferenceSession(byte[] model)
+ {
+ _builtInSessionOptions = new SessionOptions(); // need to be disposed
+ Init(model, _builtInSessionOptions);
+ }
+
+ ///
+ /// Constructs an InferenceSession from a model data in byte array, with some additional session options
+ ///
+ ///
+ ///
+ public InferenceSession(byte[] model, SessionOptions options)
+ {
+ Init(model, options);
+ }
///
/// Meta data regarding the input nodes, keyed by input names
@@ -182,17 +201,39 @@ namespace Microsoft.ML.OnnxRuntime
#region private methods
- protected void Init(string modelPath, SessionOptions options)
+ private void Init(string modelPath, SessionOptions options)
{
var envHandle = OnnxRuntime.Handle;
+ var session = IntPtr.Zero;
- _nativeHandle = IntPtr.Zero;
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out session));
+ else
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out session));
+
+ InitWithSessionHandle(session, options);
+ }
+
+ private void Init(byte[] modelData, SessionOptions options)
+ {
+ var envHandle = OnnxRuntime.Handle;
+ var session = IntPtr.Zero;
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArray(envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, out session));
+
+ InitWithSessionHandle(session, options);
+ }
+
+ ///
+ /// Initializes the session object with a native session handle
+ ///
+ /// Handle of a native session object
+ /// Session options
+ private void InitWithSessionHandle(IntPtr session, SessionOptions options)
+ {
+ _nativeHandle = session;
try
{
- if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
- NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out _nativeHandle));
- else
- NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out _nativeHandle));
// Initialize input/output metadata
_inputMetadata = new Dictionary();
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 8813011450..e3ccbb8844 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -137,6 +137,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtReleaseStatus = (DOrtReleaseStatus)Marshal.GetDelegateForFunctionPointer(api_.ReleaseStatus, typeof(DOrtReleaseStatus));
OrtCreateSession = (DOrtCreateSession)Marshal.GetDelegateForFunctionPointer(api_.CreateSession, typeof(DOrtCreateSession));
+ OrtCreateSessionFromArray = (DOrtCreateSessionFromArray)Marshal.GetDelegateForFunctionPointer(api_.CreateSessionFromArray, typeof(DOrtCreateSessionFromArray));
OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun));
OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount));
OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount));
@@ -242,6 +243,14 @@ namespace Microsoft.ML.OnnxRuntime
out IntPtr /**/ session);
public static DOrtCreateSession OrtCreateSession;
+ public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArray(
+ IntPtr /* (OrtEnv*) */ environment,
+ byte[] modelData,
+ UIntPtr modelSize,
+ IntPtr /* (OrtSessionOptions*) */sessionOptions,
+ out IntPtr /**/ session);
+ public static DOrtCreateSessionFromArray OrtCreateSessionFromArray;
+
public delegate IntPtr /*(ONNStatus*)*/ DOrtRun(
IntPtr /*(OrtSession*)*/ session,
IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index bc3a34ad71..4257d5e2b1 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -792,6 +792,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
+ [Fact]
+ private void TestInferenceSessionWithByteArray()
+ {
+ // model takes 1x5 input of fixed type, echoes back
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT.pb");
+ byte[] modelData = File.ReadAllBytes(modelPath);
+
+ using (var session = new InferenceSession(modelData))
+ {
+ var container = new List();
+ var tensorIn = new DenseTensor(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 });
+ var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
+ container.Add(nov);
+ using (var res = session.Run(container))
+ {
+ var tensorOut = res.First().AsTensor();
+ Assert.True(tensorOut.SequenceEqual(tensorIn));
+ }
+ }
+ }
[DllImport("kernel32", SetLastError = true)]
static extern IntPtr LoadLibrary(string lpFileName);