[C#] Add ML Sequences and Maps Create and Process APIs (#16648)

### Description
1) Added Sequence And Maps convenience APIs to create input Sequences
and Maps
and also visit the outputs.

2) Address OrtValue design issue when the values are created on top of
the
managed memory and the ortValues are used for sequence and maps
creation.
We should retain the original managed instances that keep the memory
pinned.
We opt to keep track of those and dispose of them within an instance of
OrtValue
that represents a Map or a Sequence.

3) Set `LangVersion` to default per [MS Versioning
Docs.](https://learn.microsoft.com/en-us/dotnet/csharp/language-reference/configure-language-version)

### Motivation and Context
1) When writing code examples, use of Map and Sequences API proved to be
cumbersome.
2) It is a BUG, that we should address, as the managed memory can move
by the GC and lead to
intermittent crashes.
3) Make use of the most feature of the C#.
This commit is contained in:
Dmitri Smirnov 2023-07-20 21:58:29 -07:00 committed by GitHub
parent 4d569f6586
commit 1e18efade5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 583 additions and 231 deletions

View file

@ -21,6 +21,7 @@ namespace Microsoft.ML.OnnxRuntime
internal class DisposableList<T> : List<T>, IDisposableReadOnlyCollection<T>
where T : IDisposable
{
private bool _disposed;
public DisposableList() { }
public DisposableList(int count) : base(count) { }
@ -30,6 +31,11 @@ namespace Microsoft.ML.OnnxRuntime
protected virtual void Dispose(bool disposing)
{
if (_disposed)
{
return;
}
if (disposing)
{
// Dispose in the reverse order.
@ -43,6 +49,7 @@ namespace Microsoft.ML.OnnxRuntime
this[i]?.Dispose();
}
this.Clear();
_disposed = true;
}
}

View file

@ -89,7 +89,7 @@ namespace Microsoft.ML.OnnxRuntime
/// \endcode
/// </example>
public static FixedBufferOnnxValue CreateFromMemory<T>(OrtMemoryInfo memoryInfo, Memory<T> memory,
TensorElementType elementType, long[] shape, long bytesSize)
TensorElementType elementType, long[] shape, long bytesSize) where T : unmanaged
{
if(elementType == TensorElementType.String)
{

View file

@ -5,6 +5,7 @@ using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace Microsoft.ML.OnnxRuntime
{
@ -24,8 +25,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
/// <param name="namedOnnxValue"></param>
/// <param name="metadata"></param>
/// <param name="disposables"></param>
/// <returns></returns>
/// <returns>OrtValye created accoding to the metadata</returns>
internal static OrtValue CreateProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
{
OrtValue result;
@ -67,8 +67,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
/// <param name="namedOnnxValue">NamedOnnxValue containing a IEnumerable<NameOnnValue></param>
/// <param name="metadata">sequence metadata</param>
/// <param name="disposables">cleanup list</param>
/// <returns></returns>
/// <returns>OrtValue that represents a sequence</returns>
/// <exception cref="OnnxRuntimeException"></exception>
private static OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
{
@ -84,8 +83,8 @@ namespace Microsoft.ML.OnnxRuntime
capacity = collection.Count;
}
// Record all the ortValues belonging to the sequence locally
using (var sequenceOrtValues = new DisposableList<OrtValue>(capacity))
DisposableList<OrtValue> sequenceOrtValues = new(capacity);
try
{
foreach (var element in seqContainer)
{
@ -97,7 +96,12 @@ namespace Microsoft.ML.OnnxRuntime
sequenceOrtValues.Add(CreateProjection(element, elementMeta));
}
return OrtValue.CreateSequence(sequenceOrtValues);
return OrtValue.CreateSequence(ref sequenceOrtValues);
}
catch(Exception)
{
sequenceOrtValues?.Dispose();
throw;
}
}
@ -107,7 +111,6 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
/// <param name="node"></param>
/// <param name="elementMeta"></param>
/// <param name="disposables"></param>
/// <returns>OrtValue</returns>
/// <exception cref="OnnxRuntimeException"></exception>
private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta)
@ -123,9 +126,13 @@ namespace Microsoft.ML.OnnxRuntime
$"Node: {node.Name} onnxruntime only supports maps with primitive types values");
}
TensorBase keys = node.GetDictionaryKeys();
using (OrtValue ortValueKeys = OrtValue.CreateFromTensorObject(keys, out TensorElementType elementTypeKeys))
Span<OrtValue> ortValues = new OrtValue[2];
var disposableGuard = new DisposableArray<OrtValue>(ortValues);
try
{
TensorBase keys = node.GetDictionaryKeys();
ortValues[0] = OrtValue.CreateFromTensorObject(keys, out TensorElementType elementTypeKeys);
if (elementTypeKeys != mapMeta.KeyDataType)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
@ -133,39 +140,40 @@ namespace Microsoft.ML.OnnxRuntime
}
TensorBase values = node.GetDictionaryValues();
using (OrtValue ortValueValues = OrtValue.CreateFromTensorObject(values, out TensorElementType elementTypeValues))
ortValues[1] = OrtValue.CreateFromTensorObject(values, out TensorElementType elementTypeValues);
if (elementTypeValues != mapValuesMeta.ElementDataType)
{
if (elementTypeValues != mapValuesMeta.ElementDataType)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}");
}
// Create Map OrtValue
return OrtValue.CreateMap(ortValueKeys, ortValueValues);
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}");
}
// Create Map OrtValue
return OrtValue.CreateMap(ref ortValues[0], ref ortValues[1]);
}
catch (Exception)
{
disposableGuard.Dispose();
throw;
}
}
/// <summary>
/// This pins memory that is contained within DenseTensor.
/// </summary>
/// <param name="node">NodeOnnxValue containing DenseTensor</param>
/// <param name="elementMeta"></param>
/// <param name="disposables">cleanup list</param>
/// <returns></returns>
/// <exception cref="OnnxRuntimeException"></exception>
private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
{
if (!(node.Value is TensorBase))
if (node.Value is not TensorBase)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
}
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);
try
try
{
if (elementType != elementMeta.ElementDataType)
{
@ -173,7 +181,7 @@ namespace Microsoft.ML.OnnxRuntime
$"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}");
}
}
catch(Exception)
catch (Exception)
{
ortValue.Dispose();
throw;

View file

@ -66,7 +66,7 @@
<PropertyGroup>
<Platforms>AnyCPU;x86</Platforms>
<LangVersion>7.3</LangVersion>
<LangVersion>default</LangVersion>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>..\..\OnnxRuntime.snk</AssemblyOriginatorKeyFile>

View file

@ -5,6 +5,7 @@ using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
@ -39,25 +40,60 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
public class OrtValue : IOrtValueOwner, IDisposable
{
// OrtValues that are members of Sequences or Maps that map. They potentially map managed memory and we need to keep them around.
// this exists only when we deal with compose ML types.
private DisposableList<OrtValue> _compositeMembers;
private IntPtr _handle;
private MemoryHandle? _memHandle; // Present when the OrtValue is created on top of managed memory
private bool _disposed;
/// <summary>
/// Constructor. The newly constructed OrtValue takes ownership of the native OrtValue instance
/// and disposes of it when the OrtValue instance is disposed.
/// </summary>
/// <param name="handle">Pointer to a native instance of OrtValue</param>
/// <param name="onnxValueType">OnnxValue type if known, otherwise the constructor would interrogate
/// the handle</param>
internal OrtValue(IntPtr handle, OnnxValueType onnxValueType = OnnxValueType.ONNX_TYPE_UNKNOWN)
internal OrtValue(IntPtr handle)
{
_handle = handle;
OnnxType = onnxValueType;
if (OnnxType == OnnxValueType.ONNX_TYPE_UNKNOWN)
InitOnnxType();
}
/// <summary>
/// Constructor. The newly constructed OrtValue takes ownership of the native OrtValue instance
/// </summary>
/// <param name="handle"></param>
/// <param name="onnxValueType"></param>
/// <exception cref="ArgumentException">thrown when onnxValue type is not known</exception>
internal OrtValue(IntPtr handle, OnnxValueType onnxValueType)
{
if (onnxValueType == OnnxValueType.ONNX_TYPE_UNKNOWN)
{
InitOnnxType();
throw new ArgumentException("onnxValueType argument is passed as unknown");
}
_handle = handle;
OnnxType = onnxValueType;
}
/// <summary>
/// Constructor. The newly constructed OrtValue takes ownership of the native OrtValue instance
/// and disposes of it when the OrtValue instance is disposed. The instance will take ownership and will
/// dispose of compositeMembers instances.
///
/// This constructor can only throw if OnnxType is not specified.
/// </summary>
/// <param name="handle">native ortValue handle</param>
/// <param name="onnxValueType">must one of the valid types</param>
/// <param name="compositeMembers">For composite types this contains dependent ortValues such as members of a sequence
/// or keys/values for the map, that may have been created on top of the managed memory and must be disposed
/// with the new ortValue. This container will be taken the ownership of and the argument will be set to null.</param>
/// <exception cref="ArgumentException">throws when onnxValueType is not specified</exception>
internal OrtValue(IntPtr handle, OnnxValueType onnxValueType, ref DisposableList<OrtValue> compositeMembers)
{
if (onnxValueType == OnnxValueType.ONNX_TYPE_UNKNOWN)
{
throw new ArgumentException("onnxValueType argument is passed as unknown");
}
_handle = handle;
OnnxType = onnxValueType;
_compositeMembers = compositeMembers;
compositeMembers = null;
}
/// <summary>
@ -165,7 +201,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <typeparam name="T"></typeparam>
/// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
/// <exception cref="OnnxRuntimeException"></exception>
public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : struct
public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));
return MemoryMarshal.Cast<byte, T>(byteSpan);
@ -185,7 +221,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns>Typed Span over the native buffer</returns>
public Span<T> GetTensorMutableDataAsSpan<T>() where T : struct
public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));
return MemoryMarshal.Cast<byte, T>(byteSpan);
@ -505,6 +541,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <returns>A disposable OrtValue instance</returns>
/// <exception cref="OnnxRuntimeException"></exception>
public static OrtValue CreateTensorValueFromMemory<T>(OrtMemoryInfo memoryInfo, Memory<T> memory, long[] shape)
where T : unmanaged
{
var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");
@ -561,7 +598,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <param name="data">managed data buffer</param>
/// <param name="shape">shape that describes the buffer</param>
/// <returns>A disposable OrtValue instance</returns>
public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape)
public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) where T : unmanaged
{
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<T>(data), shape);
}
@ -847,56 +884,340 @@ namespace Microsoft.ML.OnnxRuntime
/// All OrtValues in the collection must be of the same Onnx type.
/// I.e. (Tensor, SparseTensor, Map, Sequence, etc.)
///
/// All OrtValues are internally ref-counted and stored within the sequence OrtValue
/// so the input OrtValues can be disposed of after this call.
/// The ortValues that are passed as argument are taken possession of by the newly
/// created OrtValue. The caller should not dispose them, unless this call fails.
///
/// The ortValues would be empty on successful return.
/// </summary>
/// <param name="ortValues">a collection of OrtValues</param>
/// <param name="ortValues">a collection of OrtValues. On success the ortValues contained in the list
/// are taken ownership of and the list is cleared.</param>
/// <returns>A disposable instance of OrtValues</returns>
/// <exception cref="ArgumentNullException"></exception>
public static OrtValue CreateSequence(IReadOnlyCollection<OrtValue> ortValues)
public static OrtValue CreateSequence(ICollection<OrtValue> ortValues)
{
if (ortValues is null)
{
throw new ArgumentNullException(nameof(ortValues));
}
var handles = new IntPtr[ortValues.Count];
for (int i = 0; i < ortValues.Count; i++)
if (ortValues.IsReadOnly)
{
handles[i] = ortValues.ElementAt(i).Handle;
throw new ArgumentException("ortValues argument can not be a readonly collection");
}
var compositeMembers = new DisposableList<OrtValue>(ortValues);
try
{
var result = CreateSequence(ref compositeMembers);
Debug.Assert(compositeMembers is null, "Must be null on success");
ortValues.Clear();
return result;
}
catch (Exception)
{
// The caller is responsible for disposing the ortValues
compositeMembers?.Clear();
throw;
}
}
/// <summary>
/// Creates a sequence from the values in compositeMembers
/// The argument is taken possession of and is nullified on successful return.
/// </summary>
/// <param name="compositeMembers">sequence ortValues</param>
/// <returns>OrtValue instance representing a Sequence</returns>
internal static OrtValue CreateSequence(ref DisposableList<OrtValue> compositeMembers)
{
var handles = new IntPtr[compositeMembers.Count];
for (int i = 0; i < compositeMembers.Count; i++)
{
handles[i] = compositeMembers[i].Handle;
}
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(handles,
(UIntPtr)ortValues.Count, (IntPtr)OnnxValueType.ONNX_TYPE_SEQUENCE,
out IntPtr sequenceHandle));
return new OrtValue(sequenceHandle, OnnxValueType.ONNX_TYPE_SEQUENCE);
(UIntPtr)handles.Length, (IntPtr)OnnxValueType.ONNX_TYPE_SEQUENCE,
out IntPtr sequenceHandle));
return new OrtValue(sequenceHandle, OnnxValueType.ONNX_TYPE_SEQUENCE, ref compositeMembers);
}
/// <summary>
/// A delegate type that is expected to process each OrtValue in a sequence.
/// </summary>
/// <param name="ortValue">OrtValue that holds sequence element</param>
/// <param name="index">ordinal of the value</param>
public delegate void SequenceElementVisitor(OrtValue ortValue, int index);
/// <summary>
/// Feeds each OrtValue in a sequence to the visitor delegate.
/// This helps users to avoid dealing each value life-span
/// </summary>
/// <param name="visitor">visitor delegate</param>
/// <param name="allocator">allocator to use for intermediate ort values</param>
/// <exception cref="OnnxRuntimeException"></exception>
public void ProcessSequence(SequenceElementVisitor visitor, OrtAllocator allocator)
{
if (OnnxType != OnnxValueType.ONNX_TYPE_SEQUENCE)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"OrtValue.OnnxType of {OnnxType} is not a sequence");
}
int count = GetValueCount();
for (int i = 0; i < count; i++)
{
using var ortValue = GetValue(i, allocator);
visitor(ortValue, i);
}
}
/// <summary>
/// Creates a map OrtValue with keys and values.
/// ORT supports only a subset of types for keys and values.
/// We are not restricting them here.
/// On a high level the Onnxruntime representation of the map always consists of two
/// OrtValues, keys and values.
///
/// All OrtValues are internally ref-counted and stored within the map OrtValue
/// so the input OrtValues can be disposed of after this call.
/// According to ONNX standard map keys can be unmanaged types only (or strings).
/// Those keys are contained in a single tensor within OrtValue keys.
///
/// Map values, on the other hand, can be composite types. The values parameter
/// can either contain a single tensor with unmanaged map values with the same number of
/// elements as the keys, or it can be a sequence of OrtValues,
/// each of those can be a composite type (tensor, sequence, map). If it is a sequence,
/// then the number of elements must match the number of elements in keys.
///
/// Keys and values must be in the same order.
///
/// ORT supports only a subset of types for keys and values, however, this API does not
/// restrict it.
///
/// The ortValues that are passed as argument are taken possession of by the newly
/// created OrtValue. The caller should not dispose them, unless this call fails.
///
/// Keys and values arguments will be set to null on success.
/// </summary>
/// <param name="keys">Contains keys</param>
/// <param name="values">Contains values</param>
/// <returns>A disposable OrtValue</returns>
/// <exception cref="ArgumentNullException"></exception>
public static OrtValue CreateMap(OrtValue keys, OrtValue values)
public static OrtValue CreateMap(ref OrtValue keys, ref OrtValue values)
{
if (keys is null || values is null)
{
throw new ArgumentNullException($"keys or/and values are null");
throw new ArgumentNullException("keys or/and values are null");
}
IntPtr[] handles = { keys.Handle, values.Handle };
NativeApiStatus.VerifySuccess(
NativeMethods.OrtCreateValue(handles, (UIntPtr)handles.Length, (IntPtr)OnnxValueType.ONNX_TYPE_MAP,
out IntPtr mapHandle));
return new OrtValue(mapHandle, OnnxValueType.ONNX_TYPE_MAP);
var compositeMembers = new DisposableList<OrtValue>
{
keys,
values
};
keys = null;
values = null;
// This constructor will not throw.
return new OrtValue(mapHandle, OnnxValueType.ONNX_TYPE_MAP, ref compositeMembers);
}
/// <summary>
/// This API helps to quickly creates a map OrtValue with unmanaged (primitive) keys and values specified as arrays.
/// This helps the user not to create OrtValues for keys and values separately and deal only with the final result.
/// The map would consist of two tensors, one for keys and one for values.
///
/// The OrtValues would be created on top of the managed memory arrays and use it directly.
/// The number of elements in keys and values must be the same and they must be in order.
///
/// The types must be unmanaged.
/// </summary>
/// <typeparam name="K">keys type</typeparam>
/// <typeparam name="V">values type</typeparam>
/// <param name="keys">array of keys of K type</param>
/// <param name="values">array of values of V type</param>
/// <returns>OrtValue instance</returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="ArgumentException"></exception>
public static OrtValue CreateMap<K, V>(K[] keys, V[] values) where K : unmanaged where V : unmanaged
{
if (keys is null || values is null)
{
throw new ArgumentNullException("Keys or/and values are null");
}
if (keys.Length != values.Length)
{
throw new ArgumentException("Expecting keys and values same len. " +
$"Received keys: {keys.Length}, Values: {values.Length}");
}
long[] shape = { keys.Length };
Span<OrtValue> ortValues = new OrtValue[2];
var disposableGuard = new DisposableArray<OrtValue>(ortValues);
try
{
ortValues[0] = CreateTensorValueFromMemory(keys, shape);
ortValues[1] = CreateTensorValueFromMemory(values, shape);
return CreateMap(ref ortValues[0], ref ortValues[1]);
}
catch (Exception)
{
disposableGuard.Dispose();
throw;
}
}
/// <summary>
/// Creates a map OrtValue with string keys and non-string values.
/// This helps the user not to create OrtValues for keys and values separately.
/// The number of elements in keys and values must be the same and they must be in order.
/// The map would consist of two tensors, one for keys and one for values.
///
/// string keys would be converted to UTF-8 encoding and copied to an allocated native memory.
/// The OrtValue for values would be created on top of the managed memory using it directly.
///
/// The values type must be unmanaged.
/// </summary>
/// <typeparam name="V"></typeparam>
/// <param name="keys">Collection of strings</param>
/// <param name="values"></param>
/// <returns>OrtValue instance</returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="ArgumentException"></exception>
public static OrtValue CreateMapWithStringKeys<V>(IReadOnlyCollection<string> keys, V[] values) where V : unmanaged
{
if (keys is null || values is null)
{
throw new ArgumentNullException("Keys or/and values are null");
}
if (keys.Count != values.Length)
{
throw new ArgumentException("Expecting keys and values same len. " +
$"Received keys: {keys.Count}, Values: {values.Length}");
}
long[] shape = { keys.Count };
Span<OrtValue> ortValues = new OrtValue[2];
var disposableGuard = new DisposableArray<OrtValue>(ortValues);
try
{
ortValues[0] = CreateTensorWithEmptyStrings(OrtAllocator.DefaultInstance, shape);
int count = 0;
foreach (var key in keys)
{
ortValues[0].FillStringTensorElement(key.AsSpan(), count++);
}
ortValues[1] = CreateTensorValueFromMemory(values, shape);
return CreateMap(ref ortValues[0], ref ortValues[1]);
}
catch (Exception)
{
disposableGuard.Dispose();
throw;
}
}
/// <summary>
/// Creates a map OrtValue with non-string keys and string values.
///
/// This helps the user not to create OrtValues for keys and values separately.
/// The number of elements in keys and values must be the same and they must be in order.
///
/// The OrtValue for keys would be created on top of the managed memory using it directly.
/// string values would be converted to UTF-8 encoding and copied to an allocated native memory.
///
/// </summary>
/// <typeparam name="K">unmanaged type of keys</typeparam>
/// <param name="keys"></param>
/// <param name="values">collection of string values</param>
/// <returns>Instance of OrtValue</returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="ArgumentException"></exception>
public static OrtValue CreateMapWithStringValues<K>(K[] keys, IReadOnlyCollection<string> values) where K : unmanaged
{
if (keys is null || values is null)
{
throw new ArgumentNullException("Keys or/and values are null");
}
if (keys.Length != values.Count)
{
throw new ArgumentException("Expecting keys and values same len. " +
$"Received keys: {keys.Length}, Values: {values.Count}");
}
long[] shape = { keys.Length };
Span<OrtValue> ortValues = new OrtValue[2];
var disposableGuard = new DisposableArray<OrtValue>(ortValues);
try
{
ortValues[0] = CreateTensorValueFromMemory(keys, shape);
ortValues[1] = CreateTensorWithEmptyStrings(OrtAllocator.DefaultInstance, shape);
int count = 0;
foreach (var value in values)
{
ortValues[1].FillStringTensorElement(value.AsSpan(), count++);
}
return CreateMap(ref ortValues[0], ref ortValues[1]);
}
catch (Exception)
{
disposableGuard.Dispose();
throw;
}
}
/// <summary>
/// A public delegate that will be invoked once with map keys and values.
/// The delegate helps not to deal with the lifespan of intermediate OrtValues.
/// Typically, when one uses GetValue() API, it creates a copy of OrtValue
/// that points to the same buffer as keys or values. This API helps to deal with those
/// temporary instances and avoid leaks.
///
/// According to ONNX standard map keys can be unmanaged types only (or strings).
/// Those keys are contained in a single tensor within OrtValue keys. So you can query those
/// directly from keys argument.
///
/// Map values, on the other hand, can be composite types. The values parameter
/// can either contain a single tensor with unmanaged map values with the same number of
/// elements as the keys, or it can be a sequence of OrtValues,
/// each of those can be a composite type (tensor, sequence, map). If it is a sequence,
/// then the number of elements must match the number of elements in keys.
///
/// Depending on the structure of the values, one will either directly query a single tensor
/// from values, or will have to iterate over the sequence of OrtValues and visit each of those
/// resulting in a recursive visitation.
/// </summary>
/// <param name="keys">This would always represent a tensor</param>
/// <param name="values">Can be any of the Onnx types, but they would all reduce to tensors eventually</param>
public delegate void MapVisitor(OrtValue keys, OrtValue values);
/// <summary>
/// This API helps the user to process a map OrtValue without
/// having to deal with the lifespan of intermediate OrtValues.
///
/// each API value is fed to the vistor functor.
/// </summary>
/// <param name="visitor">visitor function</param>
/// <param name="allocator">Allocator to use for intermediate values</param>
/// <exception cref="OnnxRuntimeException"></exception>
public void ProcessMap(MapVisitor visitor, OrtAllocator allocator)
{
if (OnnxType != OnnxValueType.ONNX_TYPE_MAP)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "This OrtValue does not represent a map");
}
using var keys = GetValue(0, allocator);
using var values = GetValue(1, allocator);
visitor(keys, values);
}
private unsafe void FillStringTensorElement(char* strPtr, int strLength, int index)
@ -973,6 +1294,8 @@ namespace Microsoft.ML.OnnxRuntime
{
_memHandle?.Dispose();
_memHandle = null;
_compositeMembers?.Dispose();
_compositeMembers = null;
}
Debug.Assert(_handle != IntPtr.Zero);

View file

@ -6,11 +6,12 @@
<Platform>AnyCPU</Platform>
<OutputPath>bin\$(Configuration)\</OutputPath>
<!-- arbitrary version for testing locally. when used in a CI CurrentOnnxRuntimeVersion should always be specified and match the package being tested -->
<CurrentOnnxRuntimeVersion Condition="'$(CurrentOnnxRuntimeVersion)' == ''">1.9.0</CurrentOnnxRuntimeVersion>
<CurrentOnnxRuntimeVersion Condition="'$(CurrentOnnxRuntimeVersion)' == ''">1.15.0</CurrentOnnxRuntimeVersion>
<PackageName Condition="'$(PACKAGENAME)' == ''">Microsoft.ML.OnnxRuntime</PackageName>
<IsLinuxBuild Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Linux)))' == 'true'">true</IsLinuxBuild>
<IsWindowsBuild Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Windows)))' == 'true'">true</IsWindowsBuild>
<IsMacOSBuild Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">true</IsMacOSBuild>
<LangVersion>default</LangVersion>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>..\..\OnnxRuntime.snk</AssemblyOriginatorKeyFile>
@ -47,9 +48,10 @@
<ItemGroup>
<BuildEnvVars Include="OnnxRuntimeBuildDirectory=$(OnnxRuntimeBuildDirectory)" />
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\InferenceTest.cs"/>
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OnnxMl.cs"/>
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OnnxData.cs"/>
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\InferenceTest.cs" />
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\EqualityComparers.cs" />
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OnnxMl.cs" />
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OnnxData.cs" />
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.NetCoreApp\InferenceTest.netcore.cs" />
<Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\TestDataLoader.cs" />
@ -102,4 +104,5 @@
<ItemGroup>
<Service Include="{508349b6-6b84-4df5-91f0-309beebad82d}" />
</ItemGroup>
</Project>

View file

@ -14,7 +14,7 @@
<ProtoSrc>$(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx</ProtoSrc>
<!-- following attributes were necessary for the migrated Tensor tests -->
<LangVersion>7.2</LangVersion>
<LangVersion>default</LangVersion>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<SignAssembly>true</SignAssembly> <!-- need signing for friend access to the internals of the Tensors assembly -->
<AssemblyOriginatorKeyFile>..\..\OnnxRuntime.snk</AssemblyOriginatorKeyFile>

View file

@ -120,7 +120,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
static void VerifyTensorCreateWithData<T>(OrtValue tensor, TensorElementType dataType, long[] shape,
ReadOnlySpan<T> originalData) where T : struct
ReadOnlySpan<T> originalData) where T : unmanaged
{
// Verify invocation
var dataTypeInfo = TensorBase.GetTypeInfo(typeof(T));
@ -172,7 +172,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
// The tensor will be created on top of the managed memory. No copy is made.
// The memory should stay pinned until the OrtValue instance is disposed. This means
// stayed pinned until the end of Run() method when you are actually running inference.
using(var tensor = OrtValue.CreateTensorValueFromMemory(data, shape))
using (var tensor = OrtValue.CreateTensorValueFromMemory(data, shape))
{
VerifyTensorCreateWithData<int>(tensor, TensorElementType.Int32, shape, data);
}
@ -215,7 +215,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
private static void PopulateAndCheck<T>(T[] data) where T : struct
private static void PopulateAndCheck<T>(T[] data) where T : unmanaged
{
var typeInfo = TensorBase.GetTypeInfo(typeof(T));
Assert.NotNull(typeInfo);
@ -255,80 +255,92 @@ namespace Microsoft.ML.OnnxRuntime.Tests
private static readonly long[] ml_data_2 = { 3, 4 };
// Use this utility method to create two tensors for Map and Sequence tests
private static Tuple<OrtValue, OrtValue> CreateTwoTensors(IList<IDisposable> cleanup)
private static void CreateTwoTensors(out OrtValue val1, out OrtValue val2)
{
const int ml_data_dim = 2;
// For map tensors they must be single dimensional
long[] shape = { ml_data_dim };
unsafe
{
var ortValue_1 = OrtValue.CreateTensorValueFromMemory(ml_data_1, shape);
cleanup.Add(ortValue_1);
var ortValue_2 = OrtValue.CreateTensorValueFromMemory(ml_data_2, shape);
cleanup.Add(ortValue_2);
return Tuple.Create(ortValue_1, ortValue_2);
}
val1 = OrtValue.CreateTensorValueFromMemory(ml_data_1, shape);
val2 = OrtValue.CreateTensorValueFromMemory(ml_data_2, shape);
}
[Fact(DisplayName = "CreateMap")]
public void CreateMap()
[Fact(DisplayName = "CreateMapFromValues")]
public void CreateMapFromValues()
{
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var valTuple = CreateTwoTensors(cleanUp);
using (var map = OrtValue.CreateMap(valTuple.Item1, valTuple.Item2))
{
Assert.Equal(OnnxValueType.ONNX_TYPE_MAP, map.OnnxType);
var typeInfo = map.GetTypeInfo();
var mapInfo = typeInfo.MapTypeInfo;
Assert.Equal(TensorElementType.Int64, mapInfo.KeyType);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, mapInfo.ValueType.OnnxType);
CreateTwoTensors(out OrtValue keys, out OrtValue values);
using var map = OrtValue.CreateMap(ref keys, ref values);
Assert.Equal(OnnxValueType.ONNX_TYPE_MAP, map.OnnxType);
var typeInfo = map.GetTypeInfo();
var mapInfo = typeInfo.MapTypeInfo;
Assert.Equal(TensorElementType.Int64, mapInfo.KeyType);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, mapInfo.ValueType.OnnxType);
// Must return always 2 for map since we have two ort values
Assert.Equal(2, map.GetValueCount());
// Must return always 2 for map since we have two ort values
Assert.Equal(2, map.GetValueCount());
var keys = map.GetValue(0, OrtAllocator.DefaultInstance);
cleanUp.Add(keys);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, keys.OnnxType);
Assert.Equal(ml_data_1, keys.GetTensorDataAsSpan<long>().ToArray());
map.ProcessMap((keys, values) => {
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, keys.OnnxType);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, values.OnnxType);
Assert.Equal(ml_data_1, keys.GetTensorDataAsSpan<long>().ToArray());
Assert.Equal(ml_data_2, values.GetTensorDataAsSpan<long>().ToArray());
var vals = map.GetValue(1, OrtAllocator.DefaultInstance);
cleanUp.Add(vals);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, vals.OnnxType);
Assert.Equal(ml_data_2, vals.GetTensorDataAsSpan<long>().ToArray());
}
}
}, OrtAllocator.DefaultInstance);
}
[Fact(DisplayName = "CreateMapFromArraysUnmanaged")]
public void CreateMapFromArraysUnmanaged()
{
long[] keys = { 1, 2, 3 };
float[] vals = { 1, 2, 3 };
using var map = OrtValue.CreateMap(keys, vals);
}
[Fact(DisplayName = "CreateMapWithStringKeys")]
public void CreateMapWithStringKeys()
{
string[] keys = { "one", "two", "three" };
float[] vals = { 1, 2, 3 };
using var map = OrtValue.CreateMapWithStringKeys(keys, vals);
}
[Fact(DisplayName = "CreateMapWithStringValues")]
public void CreateMapWithStringValues()
{
long[] keys = { 1, 2, 3 };
string[] values = { "one", "two", "three" };
using var map = OrtValue.CreateMapWithStringValues(keys, values);
}
[Fact(DisplayName = "CreateSequence")]
public void CreateSequence()
{
using (var cleanUp = new DisposableListTest<IDisposable>())
CreateTwoTensors(out OrtValue val1, out OrtValue val2);
using var seqVals = new DisposableListTest<OrtValue> { val1, val2 };
using var seq = OrtValue.CreateSequence(seqVals);
Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, seq.OnnxType);
var typeInfo = seq.GetTypeInfo();
var seqInfo = typeInfo.SequenceTypeInfo;
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, seqInfo.ElementType.OnnxType);
// Will return 2 because we put 2 values in the sequence
Assert.Equal(2, seq.GetValueCount());
// Visit each element in the sequence
seq.ProcessSequence((ortValue, index) =>
{
var valTuple = CreateTwoTensors(cleanUp);
OrtValue[] seqVals = { valTuple.Item1, valTuple.Item2 };
using (var seq = OrtValue.CreateSequence(seqVals))
// We know both elements are tensors of long
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, ortValue.OnnxType);
if (index == 0)
{
Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, seq.OnnxType);
var typeInfo = seq.GetTypeInfo();
var seqInfo = typeInfo.SequenceTypeInfo;
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, seqInfo.ElementType.OnnxType);
// Will return 2 because we put 2 values in the sequence
Assert.Equal(2, seq.GetValueCount());
var item_0 = seq.GetValue(0, OrtAllocator.DefaultInstance);
cleanUp.Add(item_0);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, item_0.OnnxType);
Assert.Equal(ml_data_1, item_0.GetTensorDataAsSpan<long>().ToArray());
var item_1 = seq.GetValue(1, OrtAllocator.DefaultInstance);
cleanUp.Add(item_1);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, item_1.OnnxType);
Assert.Equal(ml_data_2, item_1.GetTensorDataAsSpan<long>().ToArray());
Assert.Equal(ml_data_1, ortValue.GetTensorDataAsSpan<long>().ToArray());
}
}
else
{
Assert.Equal(ml_data_2, ortValue.GetTensorDataAsSpan<long>().ToArray());
}
}, OrtAllocator.DefaultInstance);
}
}
}

View file

@ -16,12 +16,16 @@ namespace Microsoft.ML.OnnxRuntime.Tests
where T : IDisposable
{
public DisposableListTest()
{}
{ }
public DisposableListTest(IEnumerable<T> enumerable) : base(enumerable)
{ }
public DisposableListTest(int count)
: base(count)
{}
{ }
#region IDisposable Support
#region IDisposable Support
private bool disposedValue = false; // To detect redundant calls
protected virtual void Dispose(bool disposing)
@ -54,7 +58,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Dispose(true);
GC.SuppressFinalize(this);
}
#endregion
#endregion
}
internal struct DisposableTestPair<TValue> : IDisposable
@ -218,26 +222,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
switch (nodeMeta.OnnxValueType)
{
case OnnxValueType.ONNX_TYPE_TENSOR:
{
var tensor = Onnx.TensorProto.Parser.ParseFrom(file);
return LoadTensorPb(tensor, nodeName, nodeMeta);
}
{
var tensor = Onnx.TensorProto.Parser.ParseFrom(file);
return LoadTensorPb(tensor, nodeName, nodeMeta);
}
case OnnxValueType.ONNX_TYPE_SEQUENCE:
{
var sequence = Onnx.SequenceProto.Parser.ParseFrom(file);
return CreateNamedOnnxValueFromSequence(sequence, nodeName, nodeMeta);
}
{
var sequence = Onnx.SequenceProto.Parser.ParseFrom(file);
return CreateNamedOnnxValueFromSequence(sequence, nodeName, nodeMeta);
}
case OnnxValueType.ONNX_TYPE_MAP:
{
throw new NotImplementedException(
"Map test data format requires clarification: https://github.com/onnx/onnx/issues/5072");
}
{
throw new NotImplementedException(
"Map test data format requires clarification: https://github.com/onnx/onnx/issues/5072");
}
case OnnxValueType.ONNX_TYPE_OPTIONAL:
{
var opt = Onnx.OptionalProto.Parser.ParseFrom(file);
return CreateNamedOnnxValueFromOptional(opt, nodeName, nodeMeta);
}
{
var opt = Onnx.OptionalProto.Parser.ParseFrom(file);
return CreateNamedOnnxValueFromOptional(opt, nodeName, nodeMeta);
}
default:
throw new NotImplementedException($"Unable to load value type: {nodeMeta.OnnxValueType} not implemented");
}
@ -254,26 +258,26 @@ namespace Microsoft.ML.OnnxRuntime.Tests
switch (nodeMeta.OnnxValueType)
{
case OnnxValueType.ONNX_TYPE_TENSOR:
{
var tensor = Onnx.TensorProto.Parser.ParseFrom(file);
return new DisposableTestPair<OrtValue>(nodeName, LoadOrValueTensorPb(tensor, nodeName, nodeMeta));
}
{
var tensor = Onnx.TensorProto.Parser.ParseFrom(file);
return new DisposableTestPair<OrtValue>(nodeName, LoadOrValueTensorPb(tensor, nodeName, nodeMeta));
}
case OnnxValueType.ONNX_TYPE_SEQUENCE:
{
var sequence = Onnx.SequenceProto.Parser.ParseFrom(file);
return new DisposableTestPair<OrtValue>(nodeName, CreateOrtValueFromSequence(sequence, nodeName, nodeMeta));
}
{
var sequence = Onnx.SequenceProto.Parser.ParseFrom(file);
return new DisposableTestPair<OrtValue>(nodeName, CreateOrtValueFromSequence(sequence, nodeName, nodeMeta));
}
case OnnxValueType.ONNX_TYPE_MAP:
{
throw new NotImplementedException(
"Map test data format requires clarification: https://github.com/onnx/onnx/issues/5072");
}
{
throw new NotImplementedException(
"Map test data format requires clarification: https://github.com/onnx/onnx/issues/5072");
}
case OnnxValueType.ONNX_TYPE_OPTIONAL:
{
var opt = Onnx.OptionalProto.Parser.ParseFrom(file);
return new DisposableTestPair<OrtValue>(nodeName, CreateOrtValueFromOptional(opt, nodeName, nodeMeta));
}
{
var opt = Onnx.OptionalProto.Parser.ParseFrom(file);
return new DisposableTestPair<OrtValue>(nodeName, CreateOrtValueFromOptional(opt, nodeName, nodeMeta));
}
default:
throw new NotImplementedException($"Unable to load value type: {nodeMeta.OnnxValueType} not implemented");
}
@ -309,50 +313,50 @@ namespace Microsoft.ML.OnnxRuntime.Tests
switch (seqElemType)
{
case Onnx.SequenceProto.Types.DataType.Tensor:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR);
var sequenceOfTensors = new List<NamedOnnxValue>(sequence.TensorValues.Count);
foreach (var tensor in sequence.TensorValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
var namedOnnxValue = LoadTensorPb(tensor, elemName, elemMeta);
sequenceOfTensors.Add(namedOnnxValue);
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR);
var sequenceOfTensors = new List<NamedOnnxValue>(sequence.TensorValues.Count);
foreach (var tensor in sequence.TensorValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
var namedOnnxValue = LoadTensorPb(tensor, elemName, elemMeta);
sequenceOfTensors.Add(namedOnnxValue);
}
return NamedOnnxValue.CreateFromSequence(nodeName, sequenceOfTensors);
}
return NamedOnnxValue.CreateFromSequence(nodeName, sequenceOfTensors);
}
case Onnx.SequenceProto.Types.DataType.Sequence:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE);
var seqOfSequences = new List<NamedOnnxValue>(sequence.SequenceValues.Count);
foreach (var s in sequence.SequenceValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
seqOfSequences.Add(CreateNamedOnnxValueFromSequence(s, elemName, elemMeta));
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE);
var seqOfSequences = new List<NamedOnnxValue>(sequence.SequenceValues.Count);
foreach (var s in sequence.SequenceValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
seqOfSequences.Add(CreateNamedOnnxValueFromSequence(s, elemName, elemMeta));
}
return NamedOnnxValue.CreateFromSequence(nodeName, seqOfSequences);
}
return NamedOnnxValue.CreateFromSequence(nodeName, seqOfSequences);
}
case Onnx.SequenceProto.Types.DataType.Map:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_MAP);
var seqOfMaps = new List<NamedOnnxValue>(sequence.MapValues.Count);
foreach (var m in sequence.MapValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
seqOfMaps.Add(CreateNamedOnnxValueFromMap(m, elemName, elemMeta));
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_MAP);
var seqOfMaps = new List<NamedOnnxValue>(sequence.MapValues.Count);
foreach (var m in sequence.MapValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
seqOfMaps.Add(CreateNamedOnnxValueFromMap(m, elemName, elemMeta));
}
return NamedOnnxValue.CreateFromSequence(nodeName, seqOfMaps);
}
return NamedOnnxValue.CreateFromSequence(nodeName, seqOfMaps);
}
case Onnx.SequenceProto.Types.DataType.Optional:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL);
var seqOfOpts = new List<NamedOnnxValue>(sequence.OptionalValues.Count);
foreach (var opt in sequence.OptionalValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
seqOfOpts.Add(CreateNamedOnnxValueFromOptional(opt, elemName, elemMeta));
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL);
var seqOfOpts = new List<NamedOnnxValue>(sequence.OptionalValues.Count);
foreach (var opt in sequence.OptionalValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
seqOfOpts.Add(CreateNamedOnnxValueFromOptional(opt, elemName, elemMeta));
}
return NamedOnnxValue.CreateFromSequence(nodeName, seqOfOpts);
}
return NamedOnnxValue.CreateFromSequence(nodeName, seqOfOpts);
}
default:
throw new NotImplementedException($"Sequence test data loading does not support element type: " +
$"'{seqElemType}'");
@ -370,20 +374,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
switch ((Onnx.OptionalProto.Types.DataType)optional.ElemType)
{
case Onnx.OptionalProto.Types.DataType.Tensor:
{
var tensor = optional.TensorValue;
return LoadTensorPb(tensor, nodeName, meta);
}
{
var tensor = optional.TensorValue;
return LoadTensorPb(tensor, nodeName, meta);
}
case Onnx.OptionalProto.Types.DataType.Sequence:
{
var sequence = optional.SequenceValue;
return CreateNamedOnnxValueFromSequence(sequence, nodeName, meta);
}
{
var sequence = optional.SequenceValue;
return CreateNamedOnnxValueFromSequence(sequence, nodeName, meta);
}
case Onnx.OptionalProto.Types.DataType.Map:
{
var map = optional.MapValue;
return CreateNamedOnnxValueFromMap(map, nodeName, meta);
}
{
var map = optional.MapValue;
return CreateNamedOnnxValueFromMap(map, nodeName, meta);
}
case Onnx.OptionalProto.Types.DataType.Optional:
throw new NotImplementedException($"Unable to load '{nodeName}' optional contained within optional");
default:
@ -454,23 +458,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
switch (seqElemType)
{
case Onnx.SequenceProto.Types.DataType.Tensor:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR);
using (var sequenceOfTensors = new DisposableListTest<OrtValue>(sequence.TensorValues.Count))
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR);
using DisposableListTest<OrtValue> sequenceOfTensors = new(sequence.TensorValues.Count);
foreach (var tensor in sequence.TensorValues)
{
var element = LoadOrValueTensorPb(tensor, sequence.Name, elemMeta);
sequenceOfTensors.Add(element);
}
// Will take possession of ortValues in the sequence and will clear this container
return OrtValue.CreateSequence(sequenceOfTensors);
}
}
case Onnx.SequenceProto.Types.DataType.Sequence: // Sequence of sequences
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE);
using (var seqOfSequences = new DisposableListTest<OrtValue>(sequence.TensorValues.Count))
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE);
using DisposableListTest<OrtValue> seqOfSequences = new(sequence.TensorValues.Count);
foreach (var s in sequence.SequenceValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
@ -479,17 +481,15 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
return OrtValue.CreateSequence(seqOfSequences);
}
}
case Onnx.SequenceProto.Types.DataType.Map:
{
throw new NotImplementedException(
"Test data format for maps is under investigation");
}
case Onnx.SequenceProto.Types.DataType.Optional:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL);
using (var seqOfSequences = new DisposableListTest<OrtValue>(sequence.TensorValues.Count))
{
throw new NotImplementedException(
"Test data format for maps is under investigation");
}
case Onnx.SequenceProto.Types.DataType.Optional:
{
SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL);
using DisposableListTest<OrtValue> seqOfSequences = new(sequence.TensorValues.Count);
foreach (var opt in sequence.OptionalValues)
{
var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
@ -498,7 +498,6 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
return OrtValue.CreateSequence(seqOfSequences);
}
}
default:
throw new NotImplementedException($"Sequence test data loading does not support element type: " +
$"'{seqElemType}'");
@ -511,20 +510,20 @@ namespace Microsoft.ML.OnnxRuntime.Tests
switch ((Onnx.OptionalProto.Types.DataType)optional.ElemType)
{
case Onnx.OptionalProto.Types.DataType.Tensor:
{
var tensor = optional.TensorValue;
return LoadOrValueTensorPb(tensor, nodeName, meta);
}
{
var tensor = optional.TensorValue;
return LoadOrValueTensorPb(tensor, nodeName, meta);
}
case Onnx.OptionalProto.Types.DataType.Sequence:
{
var sequence = optional.SequenceValue;
return CreateOrtValueFromSequence(sequence, nodeName, meta);
}
{
var sequence = optional.SequenceValue;
return CreateOrtValueFromSequence(sequence, nodeName, meta);
}
case Onnx.OptionalProto.Types.DataType.Map:
{
throw new NotImplementedException(
"Test data format for maps is under investigation");
}
{
throw new NotImplementedException(
"Test data format for maps is under investigation");
}
case Onnx.OptionalProto.Types.DataType.Optional:
throw new NotImplementedException($"Unable to load '{nodeName}' optional contained within optional");
default:

View file

@ -11,7 +11,7 @@
<IsMacOSBuild Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">true</IsMacOSBuild>
<ProtoSrc>$(OnnxSourceDirectory)\onnx</ProtoSrc>
<!-- following attributes were necessary for the migrated Tensor tests -->
<LangVersion>7.2</LangVersion>
<LangVersion>default</LangVersion>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<SignAssembly>true</SignAssembly> <!-- need signing for friend access to the internals of the Tensors assembly -->
<AssemblyOriginatorKeyFile>..\..\OnnxRuntime.snk</AssemblyOriginatorKeyFile>