// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.NetworkInformation;
using System.Runtime.InteropServices;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// Represents an Inference Session on an ONNX Model.
    /// This is an IDisposable class and it must be disposed of
    /// using either an explicit call to the Dispose() method or
    /// a using() block. If this is a member of another
    /// class, that class must also become IDisposable and it must
    /// dispose of the InferenceSession in its Dispose() method.
    /// </summary>
    public class InferenceSession : IDisposable
    {
        /// <summary>
        /// A pointer to an underlying native instance of OrtSession
        /// </summary>
        private IntPtr _nativeHandle;

        /// <summary>
        /// Dictionary that represents input metadata
        /// </summary>
        private Dictionary<string, NodeMetadata> _inputMetadata;

        /// <summary>
        /// Ordered list of input names
        /// </summary>
        private List<string> _inputNames;

        /// <summary>
        /// Dictionary that represents output metadata
        /// </summary>
        private Dictionary<string, NodeMetadata> _outputMetadata;

        /// <summary>
        /// Ordered list of output names
        /// </summary>
        private List<string> _outputNames;

        /// <summary>
        /// Dictionary that represents overridableInitializers metadata
        /// </summary>
        private Dictionary<string, NodeMetadata> _overridableInitializerMetadata;

        /// <summary>
        /// This list holds Utf-8 converted input/output names allocated from a native heap
        /// and as such do not require pinning. It must be disposed of (freed).
        ///
        /// Introduced to reduce the GC burden as the names are used in every Run() call.
        /// </summary>
        private List<IntPtr> _namesMemoryPtrs;

        /// <summary>
        /// SessionOptions created by the constructors that are not given options by the caller.
        /// They are not caller-owned and need to be disposed.
        /// </summary>
        private SessionOptions _builtInSessionOptions = null;

        /// <summary>
        /// RunOptions passed by the Run() overloads that do not take RunOptions.
        /// </summary>
        private RunOptions _builtInRunOptions = null;

        /// <summary>
        /// Cached value of the ModelMetadata property; created on first access.
        /// </summary>
        private ModelMetadata _modelMetadata = null;

        private bool _disposed = false;

        /// <summary>
        /// Backing value of the ProfilingStartTimeNs property.
        /// </summary>
        private ulong _profilingStartTimeNs = 0;

        #region Public API

        /// <summary>
        /// Constructs an InferenceSession from a model file
        /// </summary>
        /// <param name="modelPath">Model path</param>
        public InferenceSession(string modelPath)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            try
            {
                Init(modelPath, _builtInSessionOptions);
            }
            catch (Exception)
            {
                // A throwing constructor never hands the caller an instance to Dispose(),
                // so the options created above must be released here.
                _builtInSessionOptions.Dispose();
                _builtInSessionOptions = null;
                throw;
            }
        }

        /// <summary>
        /// Constructs an InferenceSession from a model file and it will use
        /// the provided pre-packed weights container to store and share pre-packed buffers
        /// of shared initializers across sessions if any.
        /// </summary>
        /// <param name="modelPath">Model path</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(string modelPath, PrePackedWeightsContainer prepackedWeightsContainer)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            try
            {
                Init(modelPath, _builtInSessionOptions, prepackedWeightsContainer);
            }
            catch (Exception)
            {
                // Release the options we created; the caller cannot Dispose() a failed construction.
                _builtInSessionOptions.Dispose();
                _builtInSessionOptions = null;
                throw;
            }
        }

        /// <summary>
        /// Constructs an InferenceSession from a model file, with some additional session options
        /// </summary>
        /// <param name="modelPath">Model path</param>
        /// <param name="options">Session options (owned by the caller)</param>
        public InferenceSession(string modelPath, SessionOptions options)
        {
            Init(modelPath, options);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model file, with some additional session options
        /// and it will use the provided pre-packed weights container to store and share pre-packed buffers
        /// of shared initializers across sessions if any.
        /// </summary>
        /// <param name="modelPath">Model path</param>
        /// <param name="options">Session options</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(string modelPath, SessionOptions options,
            PrePackedWeightsContainer prepackedWeightsContainer)
        {
            Init(modelPath, options, prepackedWeightsContainer);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data in byte array
        /// </summary>
        /// <param name="model">Model as byte array</param>
        public InferenceSession(byte[] model)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            try
            {
                Init(model, _builtInSessionOptions);
            }
            catch (Exception)
            {
                // Release the options we created; the caller cannot Dispose() a failed construction.
                _builtInSessionOptions.Dispose();
                _builtInSessionOptions = null;
                throw;
            }
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data (in byte array) and it will use
        /// the provided pre-packed weights container to store and share pre-packed buffers
        /// of shared initializers across sessions if any.
        /// </summary>
        /// <param name="model">Model as byte array</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(byte[] model, PrePackedWeightsContainer prepackedWeightsContainer)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            try
            {
                Init(model, _builtInSessionOptions, prepackedWeightsContainer);
            }
            catch (Exception)
            {
                // Release the options we created; the caller cannot Dispose() a failed construction.
                _builtInSessionOptions.Dispose();
                _builtInSessionOptions = null;
                throw;
            }
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data in byte array, with some additional session options
        /// </summary>
        /// <param name="model">Model as byte array</param>
        /// <param name="options">Session options (owned by the caller)</param>
        public InferenceSession(byte[] model, SessionOptions options)
        {
            Init(model, options);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data (in byte array) with some additional
        /// session options and it will use the provided pre-packed weights container to store
        /// and share pre-packed buffers of shared initializers across sessions if any.
        /// </summary>
        /// <param name="model">Model as byte array</param>
        /// <param name="options">Session Options</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(byte[] model, SessionOptions options,
            PrePackedWeightsContainer prepackedWeightsContainer)
        {
            Init(model, options, prepackedWeightsContainer);
        }

        /// <summary>
        /// Meta data regarding the input nodes, keyed by input names
        /// </summary>
        public IReadOnlyDictionary<string, NodeMetadata> InputMetadata => _inputMetadata;

        /// <summary>
        /// Ordered list of input names that can be accessed by index.
        /// </summary>
        public IReadOnlyList<string> InputNames => _inputNames;

        /// <summary>
        /// Metadata regarding the output nodes, keyed by output names
        /// </summary>
        public IReadOnlyDictionary<string, NodeMetadata> OutputMetadata => _outputMetadata;

        /// <summary>
        /// Ordered list of output names that can be accessed by index.
        /// </summary>
        public IReadOnlyList<string> OutputNames => _outputNames;

        /// <summary>
        /// Metadata regarding the overridable initializers, keyed by node names
        /// </summary>
        public IReadOnlyDictionary<string, NodeMetadata> OverridableInitializerMetadata => _overridableInitializerMetadata;

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches all the outputs.
        /// </summary>
        /// <param name="inputs">specify a collection of <see cref="NamedOnnxValue"/> that indicates the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs)
        {
            return Run(inputs, _outputNames);
        }

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches the outputs specified in <paramref name="outputNames"/>.
        /// </summary>
        /// <param name="inputs">Specify a collection of <see cref="NamedOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputNames">Specify a collection of string that indicates the output names to fetch.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs,
            IReadOnlyCollection<string> outputNames)
        {
            return Run(inputs, outputNames, _builtInRunOptions);
        }

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches the specified outputs in <paramref name="outputNames"/>. Uses the given RunOptions for this run.
/// </summary>
        /// <param name="inputs">A collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
        /// <param name="outputNames">Names of the outputs to fetch.</param>
        /// <param name="options">RunOptions for this run.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs,
            IReadOnlyCollection<string> outputNames, RunOptions options)
        {
            var inputNamePtrs = LookupUtf8Names(inputs, input => input.Name, LookupInputMetadata);
            var outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);
            var inputHandles = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueHandleForInput,
                out DisposableArray<IDisposable> inputMemoryOwners);

            try
            {
                var producedHandles = RunImpl(options, inputNamePtrs, inputHandles, outputNamePtrs);
                try
                {
                    // Wraps each produced handle; handles not taken over are released below.
                    return CreateDisposableResult(producedHandles.Span, outputNames);
                }
                finally
                {
                    producedHandles.Dispose();
                }
            }
            finally
            {
                // Releases whatever GetOrtValuesHandles created to back the NamedOnnxValue inputs.
                inputMemoryOwners.Dispose();
            }
        }

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches all the outputs.
        /// </summary>
        /// <param name="inputNames">Names of the inputs. Should match <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">A collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
            => Run(inputNames, inputValues, _outputNames, _builtInRunOptions);

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches the outputs specified in <paramref name="outputNames"/>.
        /// </summary>
        /// <param name="inputNames">Names of the inputs. Should match <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">A collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
        /// <param name="outputNames">Names of the outputs to fetch.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<string> outputNames)
            => Run(inputNames, inputValues, outputNames, _builtInRunOptions);

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches the specified outputs in <paramref name="outputNames"/>. Uses the given RunOptions for this run.
        /// </summary>
        /// <param name="inputNames">Names of the inputs. Should match <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">A collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
        /// <param name="outputNames">Names of the outputs to fetch.</param>
        /// <param name="options">RunOptions for this run.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        /// <exception cref="ArgumentException">When the input names and values differ in length.</exception>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<string> outputNames,
            RunOptions options)
        {
            if (inputNames.Count != inputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
            }

            IntPtr[] inputNamePtrs = LookupUtf8Names(inputNames, name => name, LookupInputMetadata);
            IntPtr[] inputHandles = GetOrtValuesHandles(inputValues, input: true);
            IntPtr[] outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);

            var producedHandles = RunImpl(options, inputNamePtrs, inputHandles, outputNamePtrs);
            try
            {
                return CreateDisposableResult(producedHandles.Span, outputNames);
            }
            finally
            {
                producedHandles.Dispose();
            }
        }

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs.
        ///
        /// Outputs need to be created with correct type and dimension to accept the fetched data.
        /// </summary>
        /// <param name="inputNames">Specify a collection of string that indicates the input names. Should match <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputNames">Specify a collection of string that indicates the output names. Should match <paramref name="outputValues"/>.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
public void Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<string> outputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
            => Run(inputNames, inputValues, outputNames, outputValues, _builtInRunOptions);

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs, using <paramref name="options"/> for this run.
        ///
        /// Outputs need to be created with correct type and dimension to accept the fetched data.
        /// </summary>
        /// <param name="inputNames">Names of the inputs; must have as many entries as <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">The <see cref="FixedBufferOnnxValue"/> inputs, in the order of <paramref name="inputNames"/>.</param>
        /// <param name="outputNames">Names of the outputs; must have as many entries as <paramref name="outputValues"/>.</param>
        /// <param name="outputValues">Pre-allocated <see cref="FixedBufferOnnxValue"/> outputs, in the order of <paramref name="outputNames"/>.</param>
        /// <param name="options">RunOptions for this run.</param>
        /// <exception cref="ArgumentException">When a names collection and its values collection differ in length.</exception>
        public void Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<string> outputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues,
            RunOptions options)
        {
            if (inputNames.Count != inputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
            }

            if (outputNames.Count != outputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(outputNames)} ({outputNames.Count}) must match that of {nameof(outputValues)} ({outputValues.Count}).");
            }

            // Inputs: cached UTF-8 names and the handles wrapped by the fixed buffers.
            var inputNamePtrs = LookupUtf8Names(inputNames, name => name, LookupInputMetadata);
            var inputHandles = GetOrtValuesHandles(inputValues, input: true);

            // Outputs: handles of the caller's pre-allocated OrtValue instances.
            var outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);
            var outputHandles = GetOrtValuesHandles(outputValues, input: false);

            var status = NativeMethods.OrtRun(
                _nativeHandle,
                options.Handle,
                inputNamePtrs,
                inputHandles,
                (UIntPtr)inputNames.Count,
                outputNamePtrs,
                (UIntPtr)outputNames.Count,
                outputHandles /* pointers to Pre-allocated OrtValue instances */);
            NativeApiStatus.VerifySuccess(status);
        }

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs.
        ///
        /// Outputs need to be created with correct type and dimension to receive the fetched data.
        /// </summary>
        /// <param name="inputs">A collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
        /// <param name="outputs">A collection of <see cref="NamedOnnxValue"/> that receives the output values.</param>
        public void Run(
            IReadOnlyCollection<NamedOnnxValue> inputs,
            IReadOnlyCollection<NamedOnnxValue> outputs)
            => Run(inputs, outputs, _builtInRunOptions);

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs, using <paramref name="options"/> for this run.
        ///
        /// Outputs need to be created with correct type and dimension to receive the fetched data.
        /// </summary>
        /// <param name="inputs">A collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
        /// <param name="outputs">A collection of <see cref="NamedOnnxValue"/> that receives the output values.</param>
        /// <param name="options">RunOptions for this run.</param>
        public void Run(
            IReadOnlyCollection<NamedOnnxValue> inputs,
            IReadOnlyCollection<NamedOnnxValue> outputs,
            RunOptions options)
        {
            var inputNamePtrs = LookupUtf8Names(inputs, item => item.Name, LookupInputMetadata);
            var outputNamePtrs = LookupUtf8Names(outputs, item => item.Name, LookupOutputMetadata);

            var inputHandles = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueHandleForInput,
                out DisposableArray<IDisposable> inputMemoryOwners);
            try
            {
                var outputHandles = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueHandleForOutput,
                    out DisposableArray<IDisposable> outputMemoryOwners);
                try
                {
                    var status = NativeMethods.OrtRun(
                        _nativeHandle,
                        options.Handle,
                        inputNamePtrs,
                        inputHandles,
                        (UIntPtr)inputs.Count,
                        outputNamePtrs,
                        (UIntPtr)outputs.Count,
                        outputHandles /* pointers to Pre-allocated OrtValue instances */);
                    NativeApiStatus.VerifySuccess(status);
                }
                finally
                {
                    outputMemoryOwners.Dispose();
                }
            }
            finally
            {
                inputMemoryOwners.Dispose();
            }
        }

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs.
        ///
        /// Outputs need to be created with correct type and dimension to receive the fetched data.
        /// </summary>
        /// <param name="inputs">A collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
        /// <param name="outputNames">Names of the outputs. Should match <paramref name="outputValues"/>.</param>
        /// <param name="outputValues">Pre-allocated <see cref="FixedBufferOnnxValue"/> outputs.</param>
        public void Run(
            IReadOnlyCollection<NamedOnnxValue> inputs,
            IReadOnlyCollection<string> outputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
            => Run(inputs, outputNames, outputValues, _builtInRunOptions);

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs, using <paramref name="options"/> for this run.
        ///
        /// Outputs need to be created with correct type and dimension to receive the fetched data.
        /// </summary>
        /// <param name="inputs">A collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
        /// <param name="outputNames">Names of the outputs. Should match <paramref name="outputValues"/>.</param>
        /// <param name="outputValues">Pre-allocated <see cref="FixedBufferOnnxValue"/> outputs.</param>
        /// <param name="options">RunOptions for this run.</param>
        /// <exception cref="ArgumentException">When the output names and values differ in length.</exception>
        public void Run(
            IReadOnlyCollection<NamedOnnxValue> inputs,
            IReadOnlyCollection<string> outputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues,
            RunOptions options)
        {
            if (outputNames.Count != outputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(outputNames)} ({outputNames.Count}) must match that of {nameof(outputValues)} ({outputValues.Count}).");
            }

            var inputNamePtrs = LookupUtf8Names(inputs, item => item.Name, LookupInputMetadata);
            var outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);
            var outputHandles = GetOrtValuesHandles(outputValues, input: false);

            var inputHandles = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueHandleForInput,
                out DisposableArray<IDisposable> inputMemoryOwners);
            try
            {
                var status = NativeMethods.OrtRun(
                    _nativeHandle,
                    options.Handle,
                    inputNamePtrs,
                    inputHandles,
                    (UIntPtr)inputs.Count,
                    outputNamePtrs,
                    (UIntPtr)outputNames.Count,
                    outputHandles /* pointers to Pre-allocated OrtValue instances */);
                NativeApiStatus.VerifySuccess(status);
            }
            finally
            {
                inputMemoryOwners.Dispose();
            }
        }

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs.
        ///
        /// Outputs need to be created with correct type and dimension to receive the fetched data.
        /// </summary>
        /// <param name="inputNames">Names of the inputs. Should match <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">A collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
        /// <param name="outputs">A collection of <see cref="NamedOnnxValue"/> that receives the output values.</param>
        public void Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<NamedOnnxValue> outputs)
            => Run(inputNames, inputValues, outputs, _builtInRunOptions);

        /// <summary>
        /// Runs the loaded model for the given inputs and outputs, using <paramref name="options"/> for this run.
        ///
        /// Outputs need to be created with correct type and dimension to receive the fetched data.
        /// </summary>
        /// <param name="inputNames">Names of the inputs. Should match <paramref name="inputValues"/>.</param>
        /// <param name="inputValues">A collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
        /// <param name="outputs">A collection of <see cref="NamedOnnxValue"/> that receives the output values.</param>
        /// <param name="options">RunOptions for this run.</param>
        /// <exception cref="ArgumentException">When the input names and values differ in length.</exception>
        public void Run(
            IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<NamedOnnxValue> outputs,
            RunOptions options)
        {
            if (inputNames.Count != inputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
            }

            // Inputs come from fixed buffers; no temporary owners are needed.
            var inputNamePtrs = LookupUtf8Names(inputNames, name => name, LookupInputMetadata);
            var inputHandles = GetOrtValuesHandles(inputValues, input: true);

            // Outputs may need temporary owners that are released after the run.
            var outputNamePtrs = LookupUtf8Names(outputs, item => item.Name, LookupOutputMetadata);
            var outputHandles = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueHandleForOutput,
                out DisposableArray<IDisposable> outputMemoryOwners);
            try
            {
                var status = NativeMethods.OrtRun(
                    _nativeHandle,
                    options.Handle,
                    inputNamePtrs,
                    inputHandles,
                    (UIntPtr)inputNames.Count,
                    outputNamePtrs,
                    (UIntPtr)outputs.Count,
                    outputHandles /* pointers to Pre-allocated OrtValue instances */);
                NativeApiStatus.VerifySuccess(status);
            }
            finally
            {
                outputMemoryOwners.Dispose();
            }
        }

        /// <summary>
        /// The API runs the inference taking a collection of OrtValues as input and
        /// returning a collection of output OrtValues.
        /// </summary>
        /// <param name="runOptions">runOptions</param>
        /// <param name="inputNames">A collection of input names.
        /// To supply all names, use InputNames property</param>
        /// <param name="inputValues">Input OrtValues. The size of the collection must match the size and the order of the inputNames</param>
        /// <param name="outputNames">Output names requested.
/// To supply all names, use OutputNames property.</param>
        /// <returns>A disposable collection of disposable OrtValues</returns>
        /// <exception cref="ArgumentException">When the input names and values differ in length.</exception>
        public IDisposableReadOnlyCollection<OrtValue> Run(RunOptions runOptions, IReadOnlyCollection<string> inputNames,
            IReadOnlyCollection<OrtValue> inputValues, IReadOnlyCollection<string> outputNames)
        {
            if (inputNames.Count != inputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
            }

            var inputNamePtrs = LookupUtf8Names(inputNames, name => name, LookupInputMetadata);
            var inputHandles = inputValues.Select(value => value.Handle).ToArray();
            var outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);

            var producedHandles = RunImpl(runOptions, inputNamePtrs, inputHandles, outputNamePtrs);
            try
            {
                return CreateDisposableResult(producedHandles);
            }
            finally
            {
                producedHandles.Dispose();
            }
        }

        /// <summary>
        /// This API takes inputs as a dictionary of input names paired with input OrtValues.
        ///
        /// It returns a disposable collection of OrtValues for outputs that were designated by outputNames.
        /// </summary>
        /// <param name="runOptions">RunOptions for this run.</param>
        /// <param name="inputs">Dictionary of name/value pairs</param>
        /// <param name="outputNames">requested outputs. To request all outputs, use the OutputNames property of this session</param>
        /// <returns>A disposable collection of outputs</returns>
        public IDisposableReadOnlyCollection<OrtValue> Run(RunOptions runOptions, IReadOnlyDictionary<string, OrtValue> inputs,
            IReadOnlyCollection<string> outputNames)
        {
            var inputNamePtrs = new IntPtr[inputs.Count];
            var inputHandles = new IntPtr[inputs.Count];

            int position = 0;
            foreach (KeyValuePair<string, OrtValue> entry in inputs)
            {
                inputNamePtrs[position] = LookupInputMetadata(entry.Key).ZeroTerminatedName;
                inputHandles[position] = entry.Value.Handle;
                position++;
            }

            var outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);

            var producedHandles = RunImpl(runOptions, inputNamePtrs, inputHandles, outputNamePtrs);
            try
            {
                return CreateDisposableResult(producedHandles);
            }
            finally
            {
                producedHandles.Dispose();
            }
        }

        /// <summary>
        /// Wraps every native handle in <paramref name="disposableHandles"/> into an OrtValue and
        /// clears the wrapped slot so the handle array does not release it as well.
        /// </summary>
        /// <param name="disposableHandles">Handles produced by a run.</param>
        /// <returns>A disposable list that owns the new OrtValues.</returns>
        private static IDisposableReadOnlyCollection<OrtValue> CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles)
        {
            int handleCount = disposableHandles.Span.Length;
            var wrapped = new DisposableList<OrtValue>(handleCount);
            try
            {
                for (int slot = 0; slot < handleCount; slot++)
                {
                    wrapped.Add(new OrtValue(disposableHandles.Span[slot]));
                    disposableHandles.Span[slot] = IntPtr.Zero;
                }
                return wrapped;
            }
            catch (Exception)
            {
                wrapped.Dispose();
                throw;
            }
        }

        /// <summary>
        /// The API takes collections of inputNames/inputValues and collections of outputNames/outputValues.
        /// The sizes of the corresponding collections must match.
        ///
        /// The output OrtValues are pre-allocated and the API will fill the data into the OrtValues.
        /// These MUST be tensors. The API does not support non-tensor types for output values.
        ///
        /// The API is useful when the output values are tensors and their shapes are known, and you
        /// prefer the output to go to the pre-allocated memory. In such a case, you create
        /// output OrtValues over those pre-allocated buffers and pass them to the API.
        /// </summary>
        /// <param name="runOptions">runOptions, if null the defaults are used</param>
        /// <param name="inputNames">collection of input names.</param>
        /// <param name="inputValues">collection of input OrtValues. Must match the order and the number of input names.</param>
        /// <param name="outputNames">Requested output names.</param>
        /// <param name="outputValues">Pre-allocated output values.
        /// The order and the number must match the specified output names. Shapes must match actual output values.</param>
        /// <exception cref="ArgumentException">When a names collection and its values collection differ in length.</exception>
        public void Run(RunOptions runOptions, IReadOnlyCollection<string> inputNames, IReadOnlyCollection<OrtValue> inputValues,
            IReadOnlyCollection<string> outputNames, IReadOnlyCollection<OrtValue> outputValues)
        {
            if (inputNames.Count != inputValues.Count)
            {
                throw new ArgumentException(
                    $"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
            }

            if (outputNames.Count != outputValues.Count)
            {
                throw new ArgumentException(
                    $"Length of {nameof(outputNames)} ({outputNames.Count}) must match that of {nameof(outputValues)} ({outputValues.Count}).");
            }

            var effectiveOptions = runOptions ?? _builtInRunOptions;

            var inputNamePtrs = LookupUtf8Names(inputNames, name => name, LookupInputMetadata);
            var inputHandles = inputValues.Select(value => value.Handle).ToArray();
            var outputNamePtrs = LookupUtf8Names(outputNames, name => name, LookupOutputMetadata);
            var outputHandles = outputValues.Select(value => value.Handle).ToArray();

            var status = NativeMethods.OrtRun(
                _nativeHandle,
                effectiveOptions.Handle,
                inputNamePtrs,
                inputHandles,
                (UIntPtr)inputNames.Count,
                outputNamePtrs,
                (UIntPtr)outputNames.Count,
                outputHandles /* pointers to Pre-allocated OrtValue instances */);
            NativeApiStatus.VerifySuccess(status);
        }

        /// <summary>
        /// Create OrtIoBinding instance to bind pre-allocated buffers
        /// to input/output
        /// </summary>
        /// <returns>A new instance of OrtIoBinding</returns>
        public OrtIoBinding CreateIoBinding() => new OrtIoBinding(this);

        /// <summary>
        /// This method runs inference on the OrtIoBinding instance.
        /// The method does not return anything. This is a lightweight version of
        /// RunWithBindingAndNames(). When you bind pre-allocated buffers to the output values
        /// you may not want to fetch the outputs since you already have access to them so you can spare
        /// the expense of fetching them and pairing with names.
        /// You can still fetch the outputs by calling OrtIOBinding.GetOutputValues()
        /// </summary>
        /// <param name="runOptions">runOptions</param>
        /// <param name="ioBinding">ioBinding instance to use</param>
        public void RunWithBinding(RunOptions runOptions, OrtIoBinding ioBinding)
            => NativeApiStatus.VerifySuccess(NativeMethods.OrtRunWithBinding(Handle, runOptions.Handle, ioBinding.Handle));

        /// <summary>
        /// This method runs inference on the OrtIoBinding instance. It returns a collection of OrtValues.
        /// This method is useful when it is impossible to bind outputs to pre-allocated buffers, because
        /// the output shape is not known in advance. In this case, the OrtValues returned by this method
        /// are allocated and owned by ORT. The caller is responsible for disposing the collection.
        /// </summary>
        /// <param name="runOptions">RunOptions</param>
        /// <param name="ioBinding">Binding instance</param>
        /// <returns>A disposable collection of OrtValues</returns>
        public IDisposableReadOnlyCollection<OrtValue> RunWithBoundResults(RunOptions runOptions, OrtIoBinding ioBinding)
        {
            var status = NativeMethods.OrtRunWithBinding(Handle, runOptions.Handle, ioBinding.Handle);
            NativeApiStatus.VerifySuccess(status);
            return ioBinding.GetOutputValues();
        }

        /// <summary>
        /// This method returns a collection of DisposableNamedOnnxValue as in other interfaces.
        /// It queries names from the OrtIoBinding object and pairs them with the array of OrtValues returned
        /// from OrtIoBinding.GetOutputValues().
        ///
        /// This API will be deprecated in favor of the API that returns a collection of OrtValues.
        /// </summary>
        /// <param name="runOptions">RunOptions</param>
        /// <param name="ioBinding">OrtIoBinding instance with bindings</param>
        /// <param name="names">optional parameter. If you already know the names of the outputs you can save a native
        /// call to retrieve output names. They will be paired with the returned OrtValues and combined into DisposableNamedOnnxValues.
        /// Otherwise, the method will retrieve output names from the OrtIoBinding instance.
/// It is an error if you supply a different number of names than the returned outputs;
        /// such a mismatch is reported with an <see cref="ArgumentException"/>.</param>
        /// <returns>A disposable collection of DisposableNamedOnnxValue that encapsulate output OrtValues</returns>
        /// <exception cref="ArgumentException">The number of output names does not match the number of bound outputs.</exception>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunWithBindingAndNames(RunOptions runOptions,
            OrtIoBinding ioBinding, string[] names = null)
        {
            string[] outputNames = names;
            if (outputNames == null || outputNames.Length == 0)
            {
                outputNames = ioBinding.GetOutputNames();
            }

            NativeApiStatus.VerifySuccess(NativeMethods.OrtRunWithBinding(Handle, runOptions.Handle, ioBinding.Handle));

            var ortValues = ioBinding.GetOutputOrtValues();
            var dispValues = new DisposableArray<OrtValue>(ortValues);
            try
            {
                // Validate inside the try so the freshly fetched outputs are released on mismatch.
                // Previously, too few names silently dropped outputs and too many raised IndexOutOfRangeException.
                if (outputNames.Length != ortValues.Length)
                {
                    throw new ArgumentException(
                        $"Length of {nameof(names)} ({outputNames.Length}) must match the number of bound outputs ({ortValues.Length}).");
                }

                var result = new DisposableList<DisposableNamedOnnxValue>(ortValues.Length);
                try
                {
                    for (int i = 0; i < outputNames.Length; ++i)
                    {
                        result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames[i], ref ortValues[i]));
                    }
                    return result;
                }
                catch (Exception)
                {
                    result.Dispose();
                    throw;
                }
            }
            finally
            {
                // On success ortValues would contain nulls that will be
                // ignored. On failure, ortValues would contain at least
                // some valid OrtValue instances that need to be disposed.
                dispValues.Dispose();
            }
        }

        /// <summary>
        /// Ends profiling for the session.
        /// </summary>
        /// <returns>Returns the profile file name.</returns>
public string EndProfiling()
        {
            var allocator = OrtAllocator.DefaultInstance;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionEndProfiling(_nativeHandle,
                allocator.Pointer, out IntPtr nameHandle));
            // The returned UTF-8 string was allocated with 'allocator'; the helper converts it using that allocator.
            return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
        }

        // Delegate for string extraction from an arbitrary input/output object
        private delegate string NameExtractor<TInput>(TInput input);

        // delegate to fetch input/output OrtValue
        private delegate IntPtr OrtValueHandleExtractor(NamedOnnxValue value, NodeMetadata metadata, out IDisposable memOwner);

        // Delegate to lookup metadata for input/initializers/output
        private delegate NodeMetadata MetadataLookup(string nodeName);

        /// <summary>
        /// Checks if the name is a known input or overridable initializer name
        /// and if so, returns metadata for it.
        /// </summary>
        /// <param name="nodeName">Input or overridable initializer name</param>
        /// <returns>NodeMetadata for the nodeName</returns>
        /// <exception cref="OnnxRuntimeException">The name is neither an input nor an overridable initializer.</exception>
        private NodeMetadata LookupInputMetadata(string nodeName)
        {
            if (!_inputMetadata.TryGetValue(nodeName, out NodeMetadata meta) &&
                !_overridableInitializerMetadata.TryGetValue(nodeName, out meta))
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input name: '{nodeName}' is not in the metadata");
            }
            return meta;
        }

        /// <summary>
        /// Checks if the nodeName is a known output name and if so returns metadata for it.
        /// </summary>
        /// <param name="nodeName">Output name</param>
        /// <returns>NodeMetadata for the nodeName</returns>
        /// <exception cref="OnnxRuntimeException">The name is not a known output.</exception>
        private NodeMetadata LookupOutputMetadata(string nodeName)
        {
            if (!_outputMetadata.TryGetValue(nodeName, out NodeMetadata meta))
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Output name: '{nodeName}' is not in the metadata");
            }
            return meta;
        }

        /// <summary>
        /// Fetches/creates OrtValue for the content of the input
        /// </summary>
        /// <param name="input">Input value</param>
        /// <param name="metadata">Metadata of the input</param>
        /// <param name="memOwner">Receives any object that must be disposed after the run, if one is created</param>
        /// <returns>Native OrtValue handle</returns>
        private static IntPtr ExtractOrtValueHandleForInput(NamedOnnxValue input, NodeMetadata metadata, out IDisposable memOwner)
        {
            return input.InputToOrtValueHandle(metadata, out memOwner);
        }

        /// <summary>
        /// Fetches/Creates OrtValue for output
        /// </summary>
        /// <param name="output">Output value</param>
        /// <param name="metadata">Metadata of the output</param>
        /// <param name="memOwner">Receives any object that must be disposed after the run, if one is created</param>
        /// <returns>May return null if the onnx value type does not support pre-creation of output OrtValues</returns>
        private static IntPtr ExtractOrtValueHandleForOutput(NamedOnnxValue output, NodeMetadata metadata, out IDisposable memOwner)
        {
            return output.OutputToOrtValueHandle(metadata, out memOwner);
        }

        /// <summary>
        /// Run helper: resolves the name of each value to the zero-terminated UTF-8 name
        /// cached in its NodeMetadata.
        /// </summary>
        /// <param name="values">values whose names are looked up</param>
        /// <param name="nameExtractor">extractor functor that helps extracting names from inputs</param>
        /// <param name="metaLookup">inputs/outputs metadata lookup</param>
        /// <returns>Native name pointers, in the enumeration order of <paramref name="values"/></returns>
        private static IntPtr[] LookupUtf8Names<T>(IReadOnlyCollection<T> values, NameExtractor<T> nameExtractor,
            MetadataLookup metaLookup)
        {
            var result = new IntPtr[values.Count];
            int index = 0;
            // Single pass: ElementAt(i) only uses an indexer for IList<T>; for other collections
            // it enumerates from the start each time, which made the former indexed loop O(n^2).
            foreach (var value in values)
            {
                var name = nameExtractor(value);
                NodeMetadata meta = metaLookup(name);
                result[index++] = meta.ZeroTerminatedName;
            }
            return result;
        }

        /// <summary>
        /// This function obtains ortValues for NamedOnnxValue.
        /// The problem with NamedOnnxValue is that it is not disposable and can not contain any disposable items.
        /// so calling InputToOrtValue creates a new instance of OrtValue that needs to be disposed.
        /// The deriving object DisposableNamedValue actually contains and owns OrtValue and it returns
        /// it.
/// </summary>
        /// <param name="values">a collection of NamedOnnxValues</param>
        /// <param name="metaLookup">Metadata lookup function (input/initializers/output)</param>
        /// <param name="ortValueExtractor">Produces the OrtValue handle, and an optional memory owner, for each value</param>
        /// <param name="disposer">On success, holds the memory owners; the caller must dispose it after the run</param>
        /// <returns>Native OrtValue handles in the enumeration order of <paramref name="values"/></returns>
        private static IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<NamedOnnxValue> values, MetadataLookup metaLookup,
            OrtValueHandleExtractor ortValueExtractor, out DisposableArray<IDisposable> disposer)
        {
            IDisposable[] memHolders = new IDisposable[values.Count];
            var disp = new DisposableArray<IDisposable>(memHolders);
            try
            {
                IntPtr[] result = new IntPtr[values.Count];
                int valueIndex = 0;
                // Single pass instead of ElementAt(i), which is O(i) for non-IList collections.
                foreach (var value in values)
                {
                    var meta = metaLookup(value.Name);
                    result[valueIndex] = ortValueExtractor(value, meta, out IDisposable memHolder);
                    if (memHolder != null)
                    {
                        memHolders[valueIndex] = memHolder;
                    }
                    ++valueIndex;
                }

                disposer = disp;
                return result;
            }
            catch (Exception)
            {
                disp.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Collects the OrtValue handles wrapped by fixed-buffer values.
        /// </summary>
        /// <param name="values">Fixed-buffer values</param>
        /// <param name="input">true for inputs; string tensors are rejected for outputs</param>
        /// <returns>Native OrtValue handles in the enumeration order of <paramref name="values"/></returns>
        /// <exception cref="NotSupportedException">A string-typed value is used as an output.</exception>
        private static IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> values, bool input)
        {
            var valuesArray = new IntPtr[values.Count];
            int index = 0;
            foreach (var v in values)
            {
                if (!input && v.ElementType == Tensors.TensorElementType.String)
                {
                    throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
                }
                valuesArray[index++] = v.Value.Handle;
            }
            return valuesArray;
        }

        /// <summary>
        /// Invokes OrtRun, letting ORT allocate the outputs.
        /// </summary>
        /// <returns>A disposable array that owns the output OrtValue handles written by the native run</returns>
        private DisposableOrtValueHandleArray RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames)
        {
            IntPtr[] outputValuesArray = new IntPtr[outputNames.Length];
            NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
                _nativeHandle,
                options.Handle,
                inputNames,
                inputValues,
                (UIntPtr)inputNames.Length,
                outputNames,
                (UIntPtr)outputNames.Length,
                outputValuesArray /* Empty array is passed in to receive output OrtValue pointers */
                ));

            return new DisposableOrtValueHandleArray(outputValuesArray);
        }

        /// <summary>
        /// Pairs output handles with their names as DisposableNamedOnnxValue instances.
        /// Each handle that is taken over is zeroed in <paramref name="valueHandles"/> so the caller's
        /// handle array does not release it again.
        /// </summary>
        private static IDisposableReadOnlyCollection<DisposableNamedOnnxValue> CreateDisposableResult(Span<IntPtr> valueHandles,
            IReadOnlyCollection<string> outputNames)
        {
            Debug.Assert(valueHandles.Length == outputNames.Count, "Handles and names sizes must match");
            var result = new DisposableList<DisposableNamedOnnxValue>(valueHandles.Length);
            try
            {
                int i = 0;
                foreach (var outputName in outputNames)
                {
                    var ortValue = new OrtValue(valueHandles[i]);
                    result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputName, ref ortValue));
                    valueHandles[i] = IntPtr.Zero; // Prevent double disposal
                    ++i;
                }
            }
            catch (Exception)
            {
                // Any failure, not only OnnxRuntimeException, must release the values already
                // collected; previously other exception types leaked them.
                result.Dispose();
                throw;
            }
            return result;
        }

        /// <summary>
        /// This property queries model metadata, constructs
        /// an instance of ModelMetadata and caches it
        /// </summary>
        /// <returns>Instance of ModelMetadata</returns>
        public ModelMetadata ModelMetadata
        {
            get
            {
                if (_modelMetadata != null)
                {
                    return _modelMetadata;
                }

                _modelMetadata = new ModelMetadata(this);
                return _modelMetadata;
            }
        }

        /// <summary>
        /// Return the nanoseconds of profiling's start time
        /// On some platforms, this timer may not be as precise as nanoseconds
        /// For instance, on Windows and MacOS, the precision will be ~100ns
        /// </summary>
        public ulong ProfilingStartTimeNs
        {
            get
            {
                return _profilingStartTimeNs;
            }
        }

        /// <summary>
        /// Native completion callback for RunAsync: recovers the CallbackHost from the GCHandle passed
        /// as userData, forwards the host's output values and the status, then frees the handle.
        /// </summary>
        private static void OrtCallback(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status)
        {
            var hostHdl = GCHandle.FromIntPtr(userData);
            CallbackHost host = (CallbackHost)hostHdl.Target;
            try
            {
                host.callback(host.outputValues, status);
            }
            finally
            {
                hostHdl.Free();
            }
        }

        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate void OrtCallbackDelegate(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status);

        // Stored in a static field so the delegate whose function pointer is handed to native code stays reachable.
        private static OrtCallbackDelegate ortCallback = new OrtCallbackDelegate(OrtCallback);

        private delegate void UserCallbackDelegate(IReadOnlyCollection<OrtValue> outputs, IntPtr status);

        /// <summary>
        /// State for one RunAsync call: the managed collections, the user callback and the
        /// native name/handle arrays passed to OrtRunAsync.
        /// </summary>
        private class CallbackHost
        {
            public IReadOnlyCollection<string> inputNames { get; }
            public IReadOnlyCollection<OrtValue> inputValues { get; }
            public IReadOnlyCollection<string> outputNames { get; }
            public IReadOnlyCollection<OrtValue> outputValues { get; }
            public UserCallbackDelegate callback { get; }

            public IntPtr[] rawInputNames { get; }
            public IntPtr[] rawInputValues { get; }
            public IntPtr[] rawOutputNames { get; }
            public IntPtr[] rawOutputValues { get; }

            public CallbackHost(InferenceSession session,
                                IReadOnlyCollection<string> cbInputNames,
                                IReadOnlyCollection<OrtValue> cbinputValues,
                                IReadOnlyCollection<string> cbOutputNames,
                                IReadOnlyCollection<OrtValue> cbOutputValues,
                                UserCallbackDelegate userCallback)
            {
                inputNames = cbInputNames;
                inputValues = cbinputValues;
                outputNames = cbOutputNames;
                outputValues = cbOutputValues;
                callback = userCallback;

                rawInputNames = LookupUtf8Names(inputNames, n => n, session.LookupInputMetadata);
                rawInputValues = inputValues.Select(v => v.Handle).ToArray();

                rawOutputNames = LookupUtf8Names(outputNames, n => n, session.LookupOutputMetadata);
                rawOutputValues = outputValues.Select(v => v.Handle).ToArray();
            }
        }

        private void RunAsyncInternal(RunOptions options,
                                      IReadOnlyCollection<string> inputNames,
                                      IReadOnlyCollection<OrtValue> inputValues,
                                      IReadOnlyCollection<string> outputNames,
                                      IReadOnlyCollection<OrtValue> outputValues,
                                      UserCallbackDelegate callback)
        {
            CallbackHost host = new CallbackHost(this, inputNames, inputValues, outputNames, outputValues, callback);
            // Keeps host reachable until OrtCallback frees this handle.
            var host_hdl = GCHandle.Alloc(host, GCHandleType.Normal);

            try
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtRunAsync(
                    _nativeHandle,
                    options == null ? IntPtr.Zero : options.Handle,
                    host.rawInputNames,
                    host.rawInputValues,
                    (UIntPtr)host.rawInputNames.Length,
                    host.rawOutputNames,
                    (UIntPtr)host.rawOutputNames.Length,
                    host.rawOutputValues,
                    Marshal.GetFunctionPointerForDelegate(ortCallback),
                    GCHandle.ToIntPtr(host_hdl)
                    ));
            }
            catch (Exception)
            {
                // Any failure before or during scheduling (not only OnnxRuntimeException) must free the
                // handle; VerifySuccess throws only for a failed status, in which case the handle is
                // released here rather than by OrtCallback.
                host_hdl.Free();
                throw;
            }
        }

        /// <summary>
        /// Run inference asynchronous in a thread of intra-op thread pool
        /// </summary>
        /// <param name="options">run option, can be null</param>
        /// <param name="inputNames">name of inputs</param>
        /// <param name="inputValues">input ort values</param>
        /// <param name="outputNames">name of outputs</param>
        /// <param name="outputValues">output of ort values</param>
        /// <returns>task to be awaited</returns>
        /// <exception cref="OnnxRuntimeException">The run could not be scheduled or failed.</exception>
        public async Task<IReadOnlyCollection<OrtValue>> RunAsync(RunOptions options,
                                                                  IReadOnlyCollection<string> inputNames,
                                                                  IReadOnlyCollection<OrtValue> inputValues,
                                                                  IReadOnlyCollection<string> outputNames,
                                                                  IReadOnlyCollection<OrtValue> outputValues)
        {
            // The completion callback runs on an intra-op pool thread; without this option, awaiting
            // code would resume inline on that thread, inside the native callback.
            var promise = new TaskCompletionSource<IReadOnlyCollection<OrtValue>>(TaskCreationOptions.RunContinuationsAsynchronously);
            RunAsyncInternal(options,
                             inputNames,
                             inputValues,
                             outputNames,
                             outputValues,
                             (IReadOnlyCollection<OrtValue> outputs, IntPtr status) =>
                             {
                                 try
                                 {
                                     NativeApiStatus.VerifySuccess(status);
                                     promise.SetResult(outputs);
                                 }
                                 catch (Exception ex)
                                 {
                                     promise.SetException(ex);
                                 }
                             });

            return await promise.Task.ConfigureAwait(false);
        }

        #endregion

        #region private methods

        private void Init(string modelPath, SessionOptions options,
                          PrePackedWeightsContainer prepackedWeightsContainer = null)
        {
            var envHandle = OrtEnv.Instance().Handle;
            IntPtr session;
            if (prepackedWeightsContainer == null)
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath),
                    options.Handle, out session));
            }
            else
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionWithPrepackedWeightsContainer(
                    envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath),
                    options.Handle, prepackedWeightsContainer.Pointer, out session));
            }

            InitWithSessionHandle(session);
        }

        private void Init(byte[] modelData, SessionOptions options,
                          PrePackedWeightsContainer prepackedWeightsContainer = null)
        {
            var envHandle = OrtEnv.Instance().Handle;
            IntPtr session;
            if (prepackedWeightsContainer == null)
            {
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArray(envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, out session)); } else { NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArrayWithPrepackedWeightsContainer( envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, prepackedWeightsContainer.Pointer, out session)); } InitWithSessionHandle(session); } /// /// Initializes the session object with a native session handle /// /// Value of a native session object /// Session options private void InitWithSessionHandle(IntPtr session) { _nativeHandle = session; try { // Initialize input/output metadata // get input count NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr inputCount)); // get output count NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out UIntPtr outputCount)); // get overridable initializer count NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out UIntPtr initializerCount)); int totalNameCount = (int)inputCount + (int)outputCount + (int)initializerCount; _namesMemoryPtrs = new List(totalNameCount); // get all the input names and metadata _inputMetadata = new Dictionary((int)inputCount); _inputNames = new List((int)inputCount); for (ulong i = 0; i < (ulong)inputCount; i++) { var inputMeta = GetInputMetadata(i); var iname = GetInputName(i, out IntPtr utf8); _namesMemoryPtrs.Add(utf8); inputMeta.ZeroTerminatedName = utf8; _inputNames.Add(iname); _inputMetadata[iname] = inputMeta; } // get all the output names and metadata _outputMetadata = new Dictionary((int)outputCount); _outputNames = new List((int)outputCount); for (ulong i = 0; i < (ulong)outputCount; i++) { var outputMeta = GetOutputMetadata(i); var oname = GetOutputName(i, out IntPtr utf8); _namesMemoryPtrs.Add(utf8); outputMeta.ZeroTerminatedName = utf8; _outputNames.Add(oname); _outputMetadata[oname] = outputMeta; 
} _overridableInitializerMetadata = new Dictionary((int)initializerCount); // get all the overridable initializer names and metadata for (ulong i = 0; i < (ulong)initializerCount; i++) { var meta = GetOverridableInitializerMetadata(i); var iname = GetOverridableInitializerName(i, out IntPtr utf8); _namesMemoryPtrs.Add(utf8); meta.ZeroTerminatedName = utf8; _overridableInitializerMetadata[iname] = meta; } // set profiling's start time NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetProfilingStartTimeNs(_nativeHandle, out UIntPtr startTime)); _profilingStartTimeNs = (ulong)startTime; } catch (Exception) { DisposeImpl(true); throw; } _builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call } private string GetOutputName(ulong index, out IntPtr utf8) { var allocator = OrtAllocator.DefaultInstance; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputName( _nativeHandle, (UIntPtr)index, allocator.Pointer, out IntPtr nameHandle)); NativeOnnxValueHelper.StringAndUtf8FromNative(allocator, nameHandle, out string str, out utf8); return str; } private string GetInputName(ulong index, out IntPtr utf8) { var allocator = OrtAllocator.DefaultInstance; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputName( _nativeHandle, (UIntPtr)index, allocator.Pointer, out IntPtr nameHandle)); NativeOnnxValueHelper.StringAndUtf8FromNative(allocator, nameHandle, out string str, out utf8); return str; } private string GetOverridableInitializerName(ulong index, out IntPtr utf8) { var allocator = OrtAllocator.DefaultInstance; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerName( _nativeHandle, (UIntPtr)index, allocator.Pointer, out IntPtr nameHandle)); NativeOnnxValueHelper.StringAndUtf8FromNative(allocator, nameHandle, out string str, out utf8); return str; } private NodeMetadata GetInputMetadata(ulong index) { 
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo)); try { return GetMetadataFromTypeInfo(typeInfo); } finally { NativeMethods.OrtReleaseTypeInfo(typeInfo); } } private NodeMetadata GetOutputMetadata(ulong index) { NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo)); try { return GetMetadataFromTypeInfo(typeInfo); } finally { NativeMethods.OrtReleaseTypeInfo(typeInfo); } } private NodeMetadata GetOverridableInitializerMetadata(ulong index) { NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo)); try { return GetMetadataFromTypeInfo(typeInfo); } finally { NativeMethods.OrtReleaseTypeInfo(typeInfo); } } internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) { OnnxValueType valueType; { NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out IntPtr valType)); valueType = (OnnxValueType)valType; } switch (valueType) { case OnnxValueType.ONNX_TYPE_TENSOR: case OnnxValueType.ONNX_TYPE_SPARSETENSOR: return GetTensorNodeMetadata(valueType, typeInfo); case OnnxValueType.ONNX_TYPE_SEQUENCE: return GetSequenceMetadataFromTypeInfo(typeInfo); case OnnxValueType.ONNX_TYPE_MAP: return GetMapMetadataFromTypeInfo(typeInfo); case OnnxValueType.ONNX_TYPE_OPTIONAL: return GetOptionalMetadataFromTypeInfo(typeInfo); } throw new OnnxRuntimeException(ErrorCode.NotImplemented, $"Value type: '{valueType}' not supported in this code"); } internal static NodeMetadata GetSequenceMetadataFromTypeInfo(IntPtr typeInfo) { NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToSequenceTypeInfo(typeInfo, out IntPtr sequenceTypeInfo)); // Casts API are broken. Always return success, but may return null for the result. 
if (sequenceTypeInfo == IntPtr.Zero) { throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to SequenceTypeInfo failed. The object does not represent a sequence"); } NativeApiStatus.VerifySuccess(NativeMethods.OrtGetSequenceElementType(sequenceTypeInfo, out IntPtr elementType)); try { var elementMeta = GetMetadataFromTypeInfo(elementType); var seqMeta = new SequenceMetadata(elementMeta); return new NodeMetadata(seqMeta); } finally { NativeMethods.OrtReleaseTypeInfo(elementType); } } internal static NodeMetadata GetMapMetadataFromTypeInfo(IntPtr typeInfo) { NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToMapTypeInfo(typeInfo, out IntPtr mapTypeInfo)); // Casts API are broken. Always return success, but may return null for the result. if (mapTypeInfo == IntPtr.Zero) { throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to MapTypeInfo failed. The object does not represent a map"); } NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapKeyType(mapTypeInfo, out IntPtr keyType)); NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapValueType(mapTypeInfo, out IntPtr valueTypeInfo)); try { var valueMetadata = GetMetadataFromTypeInfo(valueTypeInfo); var mapMeta = new MapMetadata((TensorElementType)keyType, valueMetadata); return new NodeMetadata(mapMeta); } finally { NativeMethods.OrtReleaseTypeInfo(valueTypeInfo); } } internal static NodeMetadata GetOptionalMetadataFromTypeInfo(IntPtr typeInfo) { // This should not be destroyed NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToOptionalTypeInfo(typeInfo, out IntPtr optTypeInfo)); // Casts API are broken. Always return success, but may return null for the result. if (optTypeInfo == IntPtr.Zero) { throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to OptionalTypeInfo failed. 
The object does not represent a optional value"); } NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOptionalContainedTypeInfo(optTypeInfo, out IntPtr elementTypeInfo)); try { var elementMetadata = GetMetadataFromTypeInfo(elementTypeInfo); var optMetadata = new OptionalMetadata(elementMetadata); return new NodeMetadata(optMetadata); } finally { NativeMethods.OrtReleaseTypeInfo(elementTypeInfo); } } internal static NodeMetadata GetTensorNodeMetadata(OnnxValueType valueType, IntPtr typeInfo) { // Fetch tensor type and shape from the TypeInfo NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out IntPtr tensorInfo)); //(IntPtr)(int)(uint) // Casts API are broken. Always return success, but may return null for the result. if (tensorInfo == IntPtr.Zero) { throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to TensorTypeInfo failed. The object does not represent a tensor"); } TensorElementType type; { NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out IntPtr el_type)); type = (TensorElementType)el_type; } NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out UIntPtr numDimensions)); long[] dimensions = new long[(int)numDimensions]; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions)); int[] intDimensions = new int[(int)numDimensions]; for (var i = 0; i < (long)numDimensions; i++) { intDimensions[i] = (int)dimensions[i]; } IntPtr[] dimensionNamePtrs = new IntPtr[(int)numDimensions]; NativeApiStatus.VerifySuccess( NativeMethods.OrtGetSymbolicDimensions(tensorInfo, dimensionNamePtrs, numDimensions)); string[] symbolicDimensions = new string[(int)numDimensions]; for (var i = 0; i < (int)numDimensions; i++) { symbolicDimensions[i] = NativeOnnxValueHelper.StringFromNativeUtf8(dimensionNamePtrs[i]); } var tensorTypeAndShape = new TensorTypeAndShape(type, intDimensions, symbolicDimensions); return new NodeMetadata(valueType, 
tensorTypeAndShape); } /// /// Other classes access /// internal IntPtr Handle { get { return _nativeHandle; } } #endregion #region IDisposable /// /// Finalizer. to cleanup session in case it runs /// and the user forgets to Dispose() of the session /// ~InferenceSession() { Dispose(false); } /// /// IDisposable implementation /// public void Dispose() { Dispose(true); GC.SuppressFinalize(this); } /// /// IDisposable implementation /// /// true if invoked from Dispose() method protected virtual void Dispose(bool disposing) { if (_disposed) { return; } DisposeImpl(disposing); } /// /// This function is also used on failure in the constructor /// /// void DisposeImpl(bool disposing) { if (disposing) { if (_namesMemoryPtrs != null) { foreach (var ptr in _namesMemoryPtrs) Marshal.FreeHGlobal(ptr); _namesMemoryPtrs = null; } // cleanup managed resources if (_builtInSessionOptions != null) { _builtInSessionOptions.Dispose(); _builtInSessionOptions = null; } if (_builtInRunOptions != null) { _builtInRunOptions.Dispose(); _builtInRunOptions = null; } } // cleanup unmanaged resources if (_nativeHandle != IntPtr.Zero) { NativeMethods.OrtReleaseSession(_nativeHandle); _nativeHandle = IntPtr.Zero; } _disposed = true; } #endregion } /// /// Represents tensor element type and its shapes /// public class TensorTypeAndShape { internal TensorTypeAndShape(TensorElementType elementType, int[] dimensions, string[] symbolicDimensions) { ElementTypeInfo = TensorBase.GetElementTypeInfo(elementType); if (ElementTypeInfo == null) { throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unregistered TensorElementType value of: " + elementType.ToString()); } ElementDataType = elementType; Dimensions = dimensions; SymbolicDimensions = symbolicDimensions; } /// /// Tensor Element type /// /// TensorElementType enum public TensorElementType ElementDataType { get; } /// /// Shape /// /// Array of dimensions public int[] Dimensions { get; } /// /// Symbolic dimensions /// /// Array of 
symbolic dimensions if present. public string[] SymbolicDimensions { get; } /// /// Tensor element metadata /// public TensorElementTypeInfo ElementTypeInfo { get; } } /// /// Represents sequnce metdata /// public class SequenceMetadata { /// /// __ctor /// /// internal SequenceMetadata(NodeMetadata elementData) { ElementMeta = elementData; } /// /// Element Metatada, recursive definition with a Tensor being a base case /// may contain maps, tensors and other sequences /// public NodeMetadata ElementMeta { get; } } /// /// The class contains metadata for an optional input/output /// public class OptionalMetadata { /// /// __ctor /// /// internal OptionalMetadata(NodeMetadata elementData) { ElementMeta = elementData; } /// /// Element Metatada, recursive definition with a Tensor being a base case /// may contain maps, tensors and sequences /// public NodeMetadata ElementMeta { get; } } /// /// Represents Map MetaData. /// Key is always a tensor denoted by an element type /// with value type being a recursive structure that may /// contain other maps, sequences or tensors. 
/// public class MapMetadata { internal MapMetadata(TensorElementType keyDataType, NodeMetadata valueMetadata) { KeyDataType = keyDataType; ValueMetadata = valueMetadata; } /// /// Key tensor data type /// /// A value of TensorElementType enum public TensorElementType KeyDataType { get; } /// /// Value metadata /// /// /// Instance of Nodemetadata for the value of the map public NodeMetadata ValueMetadata { get; } } /// /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes /// public class NodeMetadata { private readonly Object _metadata; /// /// Constructs NodeMetadata for tensor /// /// either ONNX_TYPE_TENSOR or ONNX_TYPE_SPARSETENSOR /// Tensor type and shape information internal NodeMetadata(OnnxValueType onnxValueType, TensorTypeAndShape typeAndShape) { OnnxValueType = onnxValueType; CheckTensor(); _metadata = typeAndShape; } /// /// __ctor for map metadata /// /// internal NodeMetadata(MapMetadata mapMetadata) { OnnxValueType = OnnxValueType.ONNX_TYPE_MAP; _metadata = mapMetadata; } /// /// __ctor for sequence metadata /// /// internal NodeMetadata(SequenceMetadata sequenceMetadata) { OnnxValueType = OnnxValueType.ONNX_TYPE_SEQUENCE; _metadata = sequenceMetadata; } /// /// __ctor /// /// internal NodeMetadata(OptionalMetadata optMetadata) { OnnxValueType = OnnxValueType.ONNX_TYPE_OPTIONAL; _metadata = optMetadata; } private void CheckTensor() { if (!IsTensor) { throw new OnnxRuntimeException(ErrorCode.Fail, "OnnxValueType must either be a tensor or sparse tensor"); } } /// /// Retrieves MapMetadata, valid only if this node represents a Map. 
/// /// /// when the instance does not contain map metadata public MapMetadata AsMapMetadata() { if (OnnxValueType != OnnxValueType.ONNX_TYPE_MAP) { throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Map metadata"); } return _metadata as MapMetadata; } /// /// Retrieves SequenceMetadata, valid only if this node represents a Sequence /// /// /// when the instance does not contain sequence metadata public SequenceMetadata AsSequenceMetadata() { if (OnnxValueType != OnnxValueType.ONNX_TYPE_SEQUENCE) { throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Sequence metadata"); } return _metadata as SequenceMetadata; } /// /// Retrieves Optional type metadata, valid if this node is optional /// Optional metadata is nothing more than just a container for all the usual /// element types. /// /// /// public OptionalMetadata AsOptionalMetadata() { if (OnnxValueType != OnnxValueType.ONNX_TYPE_OPTIONAL) { throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Optional metadata"); } return _metadata as OptionalMetadata; } /// /// Type value of the node /// /// A value of OnnxValueType enum public OnnxValueType OnnxValueType { get; } /// /// Node name in the natively allocated memory. /// /// Present only on the top-level instance /// metadata dictionary entries. /// /// Avoids repeated conversion and pinning /// /// This memory chunk is owned and freed by the InferenceSession /// object. /// internal IntPtr ZeroTerminatedName { get; set; } /// /// Tensor shape valid only if this is a Tensor. /// Preserved for API compatibility /// /// Array of dimensions public int[] Dimensions { get { CheckTensor(); return (_metadata as TensorTypeAndShape).Dimensions; } } /// /// Symbolic dimensions valid only if this is a Tensor. /// Preserved for API compatibility /// /// Array of symbolic dimensions if present. 
public string[] SymbolicDimensions { get { CheckTensor(); return (_metadata as TensorTypeAndShape).SymbolicDimensions; } } /// /// .NET type that corresponds to the primitive Tensor data type. /// Valid only if this is a Tensor. /// /// System.Type public System.Type ElementType { get { CheckTensor(); return (_metadata as TensorTypeAndShape).ElementTypeInfo.TensorType; } } /// /// Tensor Element Type. Valid if tensor /// public TensorElementType ElementDataType { get { CheckTensor(); return (_metadata as TensorTypeAndShape).ElementDataType; } } /// /// Convinience method to check for string /// public bool IsString { get { CheckTensor(); return (_metadata as TensorTypeAndShape).ElementTypeInfo.IsString; } } /// /// Whether it is a Tensor /// /// currently always returns true public bool IsTensor { get { return (OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) || (OnnxValueType == OnnxValueType.ONNX_TYPE_SPARSETENSOR); } } } /// /// A class that queries and caches model metadata and exposes /// it as properties /// public class ModelMetadata { private string _producerName; private string _graphName; private string _domain; private string _description; private string _graphDescription; private long _version; private Dictionary _customMetadataMap = new Dictionary(); internal ModelMetadata(InferenceSession session) { var allocator = OrtAllocator.DefaultInstance; // Get the native ModelMetadata instance associated with the InferenceSession NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetModelMetadata(session.Handle, out IntPtr modelMetadataHandle)); try { // Process producer name NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetProducerName(modelMetadataHandle, allocator.Pointer, out IntPtr producerNameHandle)); _producerName = NativeOnnxValueHelper.StringFromNativeUtf8(producerNameHandle, allocator); // Process graph name NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphName(modelMetadataHandle, allocator.Pointer, out IntPtr 
graphNameHandle)); _graphName = NativeOnnxValueHelper.StringFromNativeUtf8(graphNameHandle, allocator); // Process domain NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDomain(modelMetadataHandle, allocator.Pointer, out IntPtr domainHandle)); _domain = NativeOnnxValueHelper.StringFromNativeUtf8(domainHandle, allocator); // Process description NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDescription(modelMetadataHandle, allocator.Pointer, out IntPtr descriptionHandle)); _description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle, allocator); // Process graph description NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphDescription(modelMetadataHandle, allocator.Pointer, out IntPtr graphDescriptionHandle)); _graphDescription = NativeOnnxValueHelper.StringFromNativeUtf8(graphDescriptionHandle, allocator); // Process version NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version)); // Process CustomMetadata Map NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetCustomMetadataMapKeys(modelMetadataHandle, allocator.Pointer, out IntPtr customMetadataMapKeysHandle, out long numKeys)); // We have received an array of null terminated C strings which are the keys that we can use to lookup the custom metadata map // The OrtAllocator will finally free the customMetadataMapKeysHandle try { Span keysHandles; unsafe { keysHandles = new Span(customMetadataMapKeysHandle.ToPointer(), (int)numKeys); } using (var ortAllocationKeys = new DisposableList((int)numKeys)) { // Put all the handles to each key in the DisposableList to be disposed off in an exception-safe manner foreach (var keyHandle in keysHandles) { ortAllocationKeys.Add(new OrtMemoryAllocation(allocator, keyHandle, 0)); } // Process each key via the stored key handles foreach (var keyHandle in keysHandles) { 
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataLookupCustomMetadataMap(modelMetadataHandle, allocator.Pointer, keyHandle, out IntPtr valueHandle)); var value = NativeOnnxValueHelper.StringFromNativeUtf8(valueHandle, allocator); var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyHandle); // Put the key/value pair into the dictionary _customMetadataMap[key] = value; } } } finally { allocator.FreeMemory(customMetadataMapKeysHandle); } } finally { // Free ModelMetadata handle NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle); } } /// /// Producer name string /// /// producer name string public string ProducerName { get { return _producerName; } } /// /// Graph name for this model /// /// graph name string public string GraphName { get { return _graphName; } } /// /// Domain for this model /// /// domain name string public string Domain { get { return _domain; } } /// /// Unstructured model description /// /// description string public string Description { get { return _description; } } /// /// Unstructured graph description /// /// description string public string GraphDescription { get { return _graphDescription; } } /// /// Version number /// /// long version integer public long Version { get { return _version; } } /// /// Custom metadata key/value pairs /// /// An instance of a Dictionary public Dictionary CustomMetadataMap { get { return _customMetadataMap; } } } }