diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 7964302956..77b5b4c624 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -46,6 +46,25 @@ namespace Microsoft.ML.OnnxRuntime Init(modelPath, options); } + /// + /// Constructs an InferenceSession from a model data in byte array + /// + /// + public InferenceSession(byte[] model) + { + _builtInSessionOptions = new SessionOptions(); // need to be disposed + Init(model, _builtInSessionOptions); + } + + /// + /// Constructs an InferenceSession from a model data in byte array, with some additional session options + /// + /// + /// + public InferenceSession(byte[] model, SessionOptions options) + { + Init(model, options); + } /// /// Meta data regarding the input nodes, keyed by input names @@ -182,17 +201,39 @@ namespace Microsoft.ML.OnnxRuntime #region private methods - protected void Init(string modelPath, SessionOptions options) + private void Init(string modelPath, SessionOptions options) { var envHandle = OnnxRuntime.Handle; + var session = IntPtr.Zero; - _nativeHandle = IntPtr.Zero; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out session)); + else + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out session)); + + InitWithSessionHandle(session, options); + } + + private void Init(byte[] modelData, SessionOptions options) + { + var envHandle = OnnxRuntime.Handle; + var session = IntPtr.Zero; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArray(envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, out session)); + + InitWithSessionHandle(session, options); + } + + /// + /// Initializes the session object with a native session handle + /// + /// Handle of a native session object + /// Session options + private void InitWithSessionHandle(IntPtr session, SessionOptions options) + { + _nativeHandle = session; try { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out _nativeHandle)); - else - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out _nativeHandle)); // Initialize input/output metadata _inputMetadata = new Dictionary(); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 8813011450..e3ccbb8844 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -137,6 +137,7 @@ namespace Microsoft.ML.OnnxRuntime OrtReleaseStatus = (DOrtReleaseStatus)Marshal.GetDelegateForFunctionPointer(api_.ReleaseStatus, typeof(DOrtReleaseStatus)); OrtCreateSession = (DOrtCreateSession)Marshal.GetDelegateForFunctionPointer(api_.CreateSession, typeof(DOrtCreateSession)); + OrtCreateSessionFromArray = (DOrtCreateSessionFromArray)Marshal.GetDelegateForFunctionPointer(api_.CreateSessionFromArray, typeof(DOrtCreateSessionFromArray)); OrtRun = (DOrtRun)Marshal.GetDelegateForFunctionPointer(api_.Run, typeof(DOrtRun)); OrtSessionGetInputCount = (DOrtSessionGetInputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputCount, typeof(DOrtSessionGetInputCount)); OrtSessionGetOutputCount = (DOrtSessionGetOutputCount)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputCount, typeof(DOrtSessionGetOutputCount)); @@ -242,6 +243,14 @@ namespace Microsoft.ML.OnnxRuntime out IntPtr /**/ session); public static DOrtCreateSession OrtCreateSession; + public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArray( + IntPtr /* (OrtEnv*) */ environment, + byte[] modelData, + UIntPtr modelSize, + IntPtr /* (OrtSessionOptions*) */sessionOptions, + out IntPtr /**/ session); + public static DOrtCreateSessionFromArray OrtCreateSessionFromArray; + public delegate IntPtr /*(ONNStatus*)*/ DOrtRun( IntPtr /*(OrtSession*)*/ session, IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index bc3a34ad71..4257d5e2b1 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -792,6 +792,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } + [Fact] + private void TestInferenceSessionWithByteArray() + { + // model takes 1x5 input of fixed type, echoes back + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT.pb"); + byte[] modelData = File.ReadAllBytes(modelPath); + + using (var session = new InferenceSession(modelData)) + { + var container = new List(); + var tensorIn = new DenseTensor(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 }); + var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn); + container.Add(nov); + using (var res = session.Run(container)) + { + var tensorOut = res.First().AsTensor(); + Assert.True(tensorOut.SequenceEqual(tensorIn)); + } + } + } [DllImport("kernel32", SetLastError = true)] static extern IntPtr LoadLibrary(string lpFileName);