// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// This helper class contains methods to create native OrtValue from a managed value object,
/// plus string conversions between managed UTF-16 strings and native zero-terminated strings.
/// </summary>
internal static class NativeOnnxValueHelper
{
    /// <summary>
    /// Converts a C# UTF-16 string to a zero-terminated UTF-8 byte[] instance.
    /// </summary>
    /// <param name="s">string to be converted</param>
    /// <returns>UTF-8 encoded equivalent followed by a single zero byte</returns>
    internal static byte[] StringToZeroTerminatedUtf8(string s)
    {
        int arraySize = UTF8Encoding.UTF8.GetByteCount(s);
        // One extra byte for the terminator.
        byte[] utf8Bytes = new byte[arraySize + 1];
        var bytesWritten = UTF8Encoding.UTF8.GetBytes(s, 0, s.Length, utf8Bytes, 0);
        Debug.Assert(arraySize == bytesWritten);
        // A new array is already zeroed; the explicit write documents the intent.
        utf8Bytes[utf8Bytes.Length - 1] = 0;
        return utf8Bytes;
    }

    /// <summary>
    /// Converts the input string into a UTF-8 encoded string (no zero termination)
    /// straight into a pre-allocated native buffer. The buffer size
    /// must match the required size exactly and can be obtained in advance with
    /// System.Text.Encoding.UTF8.GetByteCount(s).
    /// </summary>
    /// <param name="strPtr">pointer to the string characters (e.g. from a <c>fixed (char* ...)</c> statement)</param>
    /// <param name="strLength">string length in chars</param>
    /// <param name="ptr">native buffer to write</param>
    /// <param name="nativeBufferSize">size of the native buffer in bytes</param>
    /// <exception cref="OnnxRuntimeException">the number of bytes written differs from <paramref name="nativeBufferSize"/></exception>
    internal unsafe static void StringToUtf8NativeMemory(char* strPtr, int strLength, IntPtr ptr, int nativeBufferSize)
    {
        // total bytes to write is size of native memory buffer
        var bytesWritten = Encoding.UTF8.GetBytes(strPtr, strLength, (byte*)ptr, nativeBufferSize);
        if (bytesWritten != nativeBufferSize)
        {
            throw new OnnxRuntimeException(ErrorCode.RuntimeException,
                $"Failed to convert to UTF8. Expected bytes: {nativeBufferSize}, written: {bytesWritten}");
        }
    }

    /// <summary>
    /// Counts the bytes that precede the zero terminator of a native C string.
    /// </summary>
    /// <param name="nativeUtf8">non-null pointer to a zero-terminated byte sequence</param>
    /// <returns>number of bytes before the first zero byte</returns>
    private static unsafe int ZeroTerminatedLength(IntPtr nativeUtf8)
    {
        byte* bytes = (byte*)nativeUtf8;
        int len = 0;
        while (bytes[len] != 0)
        {
            ++len;
        }
        return len;
    }

    /// <summary>
    /// Reads a UTF-8 encoded string from a C zero-terminated string
    /// and converts it into a C# UTF-16 encoded string.
    /// </summary>
    /// <param name="nativeUtf8">pointer to native or pinned memory where the UTF-8 bytes reside.
    /// <see cref="IntPtr.Zero"/> yields <see cref="string.Empty"/> and nothing is freed.</param>
    /// <param name="allocator">optional allocator to free nativeUtf8 if it was allocated by OrtAllocator</param>
    /// <returns>the decoded string; <see cref="string.Empty"/> for an empty or null native string</returns>
    internal static string StringFromNativeUtf8(IntPtr nativeUtf8, OrtAllocator allocator = null)
    {
        // Dereferencing a null native pointer would crash the process, and there is nothing to free.
        if (nativeUtf8 == IntPtr.Zero)
        {
            return string.Empty;
        }

        try
        {
            int len = ZeroTerminatedLength(nativeUtf8);
            if (len == 0)
            {
                return string.Empty;
            }
            unsafe
            {
                return Encoding.UTF8.GetString((byte*)nativeUtf8, len);
            }
        }
        finally
        {
            allocator?.FreeMemory(nativeUtf8);
        }
    }

    /// <summary>
    /// Reads a UTF-8 string from a native C zero-terminated string,
    /// makes a copy of it on the unmanaged heap and converts it to a C# UTF-16 string,
    /// then returns both the C# string and the unmanaged copy of the UTF-8 string.
    ///
    /// On return it deallocates the nativeUtf8 string using the specified allocator.
    /// </summary>
    /// <param name="allocator">allocator to use to free nativeUtf8</param>
    /// <param name="nativeUtf8">input zero-terminated UTF-8 string. If it is <see cref="IntPtr.Zero"/>,
    /// the outputs are <see cref="string.Empty"/> and <see cref="IntPtr.Zero"/>, and nothing is freed.</param>
    /// <param name="str">C# UTF-16 string</param>
    /// <param name="utf8">zero-terminated copy of the UTF-8 bytes allocated with
    /// <see cref="Marshal.AllocHGlobal(int)"/>; release it with <see cref="Marshal.FreeHGlobal(IntPtr)"/>.
    /// It is <see cref="IntPtr.Zero"/> (no allocation) when the input string is empty.</param>
    internal static void StringAndUtf8FromNative(OrtAllocator allocator, IntPtr nativeUtf8, out string str, out IntPtr utf8)
    {
        // Dereferencing a null native pointer would crash the process, and there is nothing to free.
        if (nativeUtf8 == IntPtr.Zero)
        {
            str = string.Empty;
            utf8 = IntPtr.Zero;
            return;
        }

        try
        {
            int len = ZeroTerminatedLength(nativeUtf8);
            if (len == 0)
            {
                str = string.Empty;
                utf8 = IntPtr.Zero;
                return;
            }

            unsafe
            {
                utf8 = Marshal.AllocHGlobal(len + 1);
                try
                {
                    // Make a copy of the UTF-8 bytes and add a zero terminator
                    // on the unmanaged heap.
                    var src = new ReadOnlySpan<byte>(nativeUtf8.ToPointer(), len);
                    var dest = new Span<byte>(utf8.ToPointer(), len + 1);
                    src.CopyTo(dest);
                    dest[len] = 0;
                    str = Encoding.UTF8.GetString((byte*)nativeUtf8, len);
                }
                catch (Exception)
                {
                    // Do not leak the copy if decoding fails.
                    Marshal.FreeHGlobal(utf8);
                    throw;
                }
            }
        }
        finally
        {
            allocator.FreeMemory(nativeUtf8);
        }
    }

    /// <summary>
    /// Serializes a string into a zero-terminated byte[] whose encoding depends on the OS.
    /// On Windows the result is UTF-16LE (<see cref="Encoding.Unicode"/>) followed by a two-byte
    /// zero terminator. On other platforms it is UTF-8 followed by a single zero byte.
    /// NOTE(review): presumably this matches the native character type on each platform
    /// (wchar_t on Windows); confirm against the native API consuming the bytes.
    /// </summary>
    /// <param name="str">string to be converted</param>
    /// <returns>platform-encoded, zero-terminated bytes</returns>
    internal static byte[] GetPlatformSerializedString(string str)
    {
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            // Appending U+0000 yields the two-byte UTF-16 terminator.
            return System.Text.Encoding.Unicode.GetBytes(str + Char.MinValue);
        }
        else
        {
            return StringToZeroTerminatedUtf8(str);
        }
    }
}
/// <summary>
/// Stack-only guard over a span of disposable objects. Disposing the guard
/// disposes every non-null element, starting from the last one.
/// </summary>
/// <typeparam name="T">element type of the guarded span</typeparam>
internal ref struct DisposableArray<T> where T : IDisposable
{
    /// <summary>
    /// Wraps (does not copy) the given span of disposables.
    /// </summary>
    /// <param name="disposables">objects to dispose when this guard is disposed</param>
    internal DisposableArray(Span<T> disposables)
    {
        Span = disposables;
    }

    /// <summary>
    /// The guarded elements; writes through this span are visible to the guard.
    /// </summary>
    internal Span<T> Span { get; private set; }

    /// <summary>
    /// Disposes the elements back to front, skipping null entries.
    /// </summary>
    public void Dispose()
    {
        // Objects created later may depend on earlier ones, so tear down newest first.
        var items = Span;
        int index = items.Length;
        while (index-- > 0)
        {
            items[index]?.Dispose();
        }
    }
}
/// <summary>
/// Stack-only guard over a span of native OrtValue handles. Disposing the guard
/// releases every non-zero handle in reverse order. Each released slot is reset to
/// <see cref="IntPtr.Zero"/>, so a repeated Dispose() does not release the same handle twice.
/// </summary>
internal ref struct DisposableOrtValueHandleArray
{
    /// <summary>
    /// The guarded handles; the guard writes through this span when it clears released slots.
    /// </summary>
    internal Span<IntPtr> Span { get; private set; }

    /// <summary>
    /// Wraps (does not copy) the given span of native OrtValue handles.
    /// </summary>
    /// <param name="handles">handles to release when this guard is disposed; zero entries are skipped</param>
    internal DisposableOrtValueHandleArray(Span<IntPtr> handles)
    {
        Span = handles;
    }

    /// <summary>
    /// Releases all non-zero handles, last first, and zeroes their slots.
    /// </summary>
    public void Dispose()
    {
        var handles = Span;
        // Dispose in the reverse order in case there are dependencies
        for (int i = handles.Length - 1; i >= 0; --i)
        {
            IntPtr handle = handles[i];
            if (handle != IntPtr.Zero)
            {
                // Clear the slot before releasing so the handle can never be released twice,
                // even if Dispose() is called again.
                handles[i] = IntPtr.Zero;
                NativeMethods.OrtReleaseValue(handle);
            }
        }
    }
}
/// <summary>
/// Encodes a string as UTF-8 directly into a zero-terminated unmanaged buffer.
/// Native code can then use the buffer without the managed string being pinned
/// for the duration of the native call. Call Dispose() to free the buffer.
/// </summary>
public unsafe struct MarshaledString : IDisposable
{
    /// <summary>
    /// Allocates the native buffer and encodes <paramref name="input"/> into it.
    /// </summary>
    /// <param name="input">string to marshal. A null input produces no allocation
    /// (<see cref="Value"/> is <see cref="IntPtr.Zero"/>). An empty string produces a one-byte buffer
    /// holding only the terminator.</param>
    internal MarshaledString(string input)
    {
        int length = 0;
        IntPtr value = IntPtr.Zero;
        if (input != null)
        {
            length = Encoding.UTF8.GetByteCount(input);
            // One extra byte for the terminating zero.
            value = Marshal.AllocHGlobal(length + 1);
            try
            {
                byte* destination = (byte*)value.ToPointer();
                if (length > 0)
                {
                    // Encode straight into native memory. This avoids a temporary managed
                    // byte[] and a second copy. The string is pinned only for this call.
                    fixed (char* source = input)
                    {
                        int written = Encoding.UTF8.GetBytes(source, input.Length, destination, length);
                        Debug.Assert(written == length);
                    }
                }
                destination[length] = 0;
            }
            catch (Exception)
            {
                // Do not leak the native buffer if encoding fails.
                Marshal.FreeHGlobal(value);
                throw;
            }
        }
        Length = length;
        Value = value;
    }

    /// <summary>
    /// Number of UTF-8 bytes in the native buffer, EXCLUDING the terminating zero
    /// (the allocation itself is Length + 1 bytes). Zero for null or empty input and after Dispose().
    /// </summary>
    internal int Length { get; private set; }

    /// <summary>
    /// Actual native buffer (zero-terminated UTF-8), or <see cref="IntPtr.Zero"/> for null input
    /// or after Dispose().
    /// </summary>
    internal IntPtr Value { get; private set; }

    /// <summary>
    /// IDisposable implementation: frees the native buffer. Safe to call more than once on
    /// the same variable. Copies of this struct share the buffer, so dispose only one of them.
    /// </summary>
    public void Dispose()
    {
        // No managed resources to dispose
        if (Value != IntPtr.Zero)
        {
            Marshal.FreeHGlobal(Value);
            Value = IntPtr.Zero;
            Length = 0;
        }
    }
}
/// <summary>
/// Keeps an array of MarshaledString instances and provides a way to dispose them all at once.
/// It is a ref struct, so it cannot implement IDisposable; callers invoke Dispose() explicitly
/// (e.g. from a finally block).
/// </summary>
public unsafe ref struct MarshaledStringArray
{
    private MarshaledString[] _values;

    /// <summary>
    /// Marshals every element of a string tensor.
    /// If marshaling fails partway, the buffers already allocated are freed before the exception propagates.
    /// </summary>
    /// <param name="inputs">string tensor; an empty tensor yields no values</param>
    internal MarshaledStringArray(Tensor<string> inputs)
    {
        _values = null;
        if (inputs.Length == 0)
        {
            return;
        }

        var values = new MarshaledString[inputs.Length];
        int created = 0;
        try
        {
            for (; created < values.Length; created++)
            {
                values[created] = new MarshaledString(inputs.GetValue(created));
            }
        }
        catch (Exception)
        {
            // The caller never receives this instance, so it could not free these buffers.
            DisposeFirst(values, created);
            throw;
        }
        _values = values;
    }

    /// <summary>
    /// Marshals every string of a sequence.
    /// If marshaling fails partway, the buffers already allocated are freed before the exception propagates.
    /// </summary>
    /// <param name="inputs">strings to marshal; null yields no values</param>
    internal MarshaledStringArray(IEnumerable<string> inputs)
    {
        _values = null;
        if (inputs is null)
        {
            return;
        }

        // Enumerate the input only once. Count() followed by foreach walks lazy sequences twice,
        // and overruns the array if the second pass yields more items.
        string[] strings = inputs as string[] ?? inputs.ToArray();
        var values = new MarshaledString[strings.Length];
        int created = 0;
        try
        {
            for (; created < values.Length; created++)
            {
                values[created] = new MarshaledString(strings[created]);
            }
        }
        catch (Exception)
        {
            // The caller never receives this instance, so it could not free these buffers.
            DisposeFirst(values, created);
            throw;
        }
        _values = values;
    }

    /// <summary>
    /// The marshaled strings (empty when there are none or after Dispose()).
    /// </summary>
    internal ReadOnlySpan<MarshaledString> Values => _values;

    /// <summary>
    /// Copies each native buffer pointer into <paramref name="pDestination"/>, index for index.
    /// </summary>
    /// <param name="pDestination">destination array; must be at least as long as <see cref="Values"/></param>
    internal void Fill(IntPtr[] pDestination)
    {
        if (_values != null)
        {
            for (var i = 0; i < _values.Length; i++)
            {
                pDestination[i] = Values[i].Value;
            }
        }
    }

    /// <summary>
    /// Frees all native buffers. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        if (_values != null)
        {
            DisposeFirst(_values, _values.Length);
            _values = null;
        }
    }

    /// <summary>
    /// Disposes the first <paramref name="count"/> entries of <paramref name="values"/> in place.
    /// </summary>
    private static void DisposeFirst(MarshaledString[] values, int count)
    {
        for (int i = 0; i < count; i++)
        {
            values[i].Dispose();
        }
    }
}
/// <summary>
/// Utility class used in SessionOptions and ProviderOptions
/// </summary>
internal class ProviderOptionsUpdater
{
    /// <summary>
    /// Marshals provider option key/value pairs to native UTF-8 strings and hands them to a
    /// native update function, then checks the returned status.
    /// </summary>
    /// <param name="providerOptions">the actual key/value option pairs</param>
    /// <param name="handle">native handle of the object whose options are updated</param>
    /// <param name="updateFunc">wraps a native method: Arg1 = handle, Arg2 = array of keys,
    /// Arg3 = array of values, Arg4 = count, result = ORT status</param>
    internal static void Update(Dictionary<string, string> providerOptions,
                                IntPtr handle,
                                Func<IntPtr, IntPtr[], IntPtr[], UIntPtr, IntPtr> updateFunc)
    {
        // Keys and Values enumerate in matching order, so index i pairs a key with its value.
        string[] names = providerOptions.Keys.ToArray();
        string[] settings = providerOptions.Values.ToArray();

        var marshaledNames = default(MarshaledStringArray);
        var marshaledSettings = default(MarshaledStringArray);
        try
        {
            marshaledNames = new MarshaledStringArray(names);
            marshaledSettings = new MarshaledStringArray(settings);

            var nativeNames = new IntPtr[names.Length];
            var nativeSettings = new IntPtr[settings.Length];
            marshaledNames.Fill(nativeNames);
            marshaledSettings.Fill(nativeSettings);

            IntPtr status = updateFunc(handle, nativeNames, nativeSettings, (UIntPtr)providerOptions.Count);
            NativeApiStatus.VerifySuccess(status);
        }
        finally
        {
            // Native buffers are freed whether or not the native call succeeded.
            marshaledSettings.Dispose();
            marshaledNames.Dispose();
        }
    }
}
}