2023-04-11 16:41:59 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
using Microsoft.ML.OnnxRuntime.Tensors;
|
|
|
|
|
using System;
|
|
|
|
|
using System.Collections.Generic;
|
|
|
|
|
using System.Diagnostics;
|
|
|
|
|
|
|
|
|
|
namespace Microsoft.ML.OnnxRuntime
|
|
|
|
|
{
|
|
|
|
|
/// <summary>
|
|
|
|
|
/// The class helps to feed the NamedOnnxValue as inference input.
|
|
|
|
|
/// It projects managed classes to OrtValues so they can be consumed
|
|
|
|
|
/// by the native onnxruntime library. If possible, it will avoid copying data.
|
|
|
|
|
/// The NamedOnnxValue can be a tensor, sequence or map.
|
|
|
|
|
/// For recursive structures, create nested NamedOnnxValue instances.
|
|
|
|
|
/// For example, a sequence instance would contain a list of NamedOnnxValue instances
|
|
|
|
|
/// that in turn may represent tensors or other ONNX values.
|
|
|
|
|
/// </summary>
|
2023-06-29 15:59:23 +00:00
|
|
|
internal class ManagedTypeProjection
|
2023-04-11 16:41:59 +00:00
|
|
|
{
|
|
|
|
|
/// <summary>
/// Entry point for projecting a managed NamedOnnxValue onto an OrtValue.
/// Checks the value type against the supplied metadata (using the element
/// metadata when the metadata describes an optional type) and dispatches
/// to the tensor, sequence or map projection routine.
/// </summary>
/// <param name="namedOnnxValue">managed value to project</param>
/// <param name="metadata">metadata the value is validated against</param>
/// <returns>OrtValue produced by the matching projection routine</returns>
/// <exception cref="OnnxRuntimeException">
/// the value type does not match the metadata, or it is not a tensor, sequence or map
/// </exception>
internal static OrtValue CreateProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
{
    // For optional types the contained element metadata drives validation and projection
    NodeMetadata meta = metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL
        ? metadata.AsOptionalMetadata().ElementMeta
        : metadata;

    if (namedOnnxValue.ValueType != meta.OnnxValueType)
    {
        throw new OnnxRuntimeException(ErrorCode.RuntimeException,
            $"NamedOnnxValue: {namedOnnxValue.Name} has value type: {namedOnnxValue.ValueType}" +
            $" expected: {meta.OnnxValueType} after optional type adjustment");
    }

    switch (namedOnnxValue.ValueType)
    {
        case OnnxValueType.ONNX_TYPE_TENSOR:
            return CreateTensorProjection(namedOnnxValue, meta);

        case OnnxValueType.ONNX_TYPE_SEQUENCE:
            return CreateSequenceProjection(namedOnnxValue, meta);

        case OnnxValueType.ONNX_TYPE_MAP:
            return CreateMapProjection(namedOnnxValue, meta);

        default:
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "ManagedTypeProjection can only project tensors, sequences, maps and optional types");
    }
}
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Projects a sequence: every element is projected through CreateProjection
/// using the sequence element metadata, and the resulting OrtValues are then
/// combined into a single sequence OrtValue.
/// </summary>
/// <param name="namedOnnxValue">NamedOnnxValue containing an IEnumerable&lt;NamedOnnxValue&gt;</param>
/// <param name="metadata">sequence metadata</param>
/// <returns>OrtValue created by OrtValue.CreateSequence from the projected elements</returns>
/// <exception cref="OnnxRuntimeException">
/// the value does not expose NamedOnnxValue elements, or an element's value type
/// differs from the type declared by the sequence element metadata
/// </exception>
private static OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
{
    var elementMeta = metadata.AsSequenceMetadata().ElementMeta;
    var elementOnnxValue = elementMeta.OnnxValueType;

    var items = namedOnnxValue.AsEnumerable<NamedOnnxValue>();
    if (items == null)
    {
        throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
            $"NamedOnnxValue: {namedOnnxValue.Name} sequence does not contain NamedOnnxValue elements");
    }

    // Pre-size the list when the element count is known up front, otherwise start at 0
    int initialCapacity = (items as ICollection<NamedOnnxValue>)?.Count ?? 0;

    // The element OrtValues are tracked locally; the list is disposed when this scope exits
    using (var projectedElements = new DisposableList<OrtValue>(initialCapacity))
    {
        foreach (var element in items)
        {
            if (element.ValueType != elementOnnxValue)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                    $"NamedOnnxValue: {namedOnnxValue.Name} sequence element expected to be {elementOnnxValue}, received {element.ValueType}");
            }

            var projected = CreateProjection(element, elementMeta);
            projectedElements.Add(projected);
        }

        return OrtValue.CreateSequence(projectedElements);
    }
}
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Projects a map. Only maps whose values are tensors are accepted, so the map
/// is expressed as two OrtValues (one built from the keys tensor, one from the
/// values tensor) that are handed to OrtValue.CreateMap.
/// </summary>
/// <param name="node">NamedOnnxValue providing the dictionary keys and values</param>
/// <param name="elementMeta">metadata describing the map (key data type and value metadata)</param>
/// <returns>OrtValue</returns>
/// <exception cref="OnnxRuntimeException">
/// the map values are not tensors, or the key/value element type does not match the metadata
/// </exception>
private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta)
{
    MapMetadata mapMeta = elementMeta.AsMapMetadata();
    Debug.Assert(mapMeta != null);

    // Nested sequences or maps as map values are rejected; values must be a tensor
    var mapValuesMeta = mapMeta.ValueMetadata;
    if (mapValuesMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
    {
        throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
            $"Node: {node.Name} onnxruntime only supports maps with primitive types values");
    }

    // Keys are projected and validated before the values are even fetched
    using (OrtValue keysOrtValue = OrtValue.CreateFromTensorObject(node.GetDictionaryKeys(), out TensorElementType elementTypeKeys))
    {
        if (elementTypeKeys != mapMeta.KeyDataType)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                $"Map key data type supplied: {elementTypeKeys} metadata expected: {mapMeta.KeyDataType}");
        }

        using (OrtValue valuesOrtValue = OrtValue.CreateFromTensorObject(node.GetDictionaryValues(), out TensorElementType elementTypeValues))
        {
            if (elementTypeValues != mapValuesMeta.ElementDataType)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                    $"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}");
            }

            // Both intermediate OrtValues are disposed when the using scopes exit
            OrtValue mapOrtValue = OrtValue.CreateMap(keysOrtValue, valuesOrtValue);
            return mapOrtValue;
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Projects a tensor value: creates an OrtValue from the TensorBase held by the node
/// and verifies that its element type matches the metadata.
/// NOTE(review): the original comment states this pins the memory contained within
/// a DenseTensor; that behavior lives in OrtValue.CreateFromTensorObject - confirm there.
/// </summary>
/// <param name="node">NamedOnnxValue whose Value is expected to be a TensorBase</param>
/// <param name="elementMeta">tensor metadata providing the expected element data type</param>
/// <returns>the created OrtValue; the caller is responsible for disposing it</returns>
/// <exception cref="OnnxRuntimeException">
/// node.Value is null or not a TensorBase, or the tensor element type differs from the metadata
/// </exception>
private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
{
    if (!(node.Value is TensorBase tensor))
    {
        // node.Value may be null here. Calling GetType() on it directly would raise
        // NullReferenceException and hide the actual argument error from the caller.
        string actualType = node.Value?.GetType().ToString() ?? "null";
        throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
            $"NamedOnnxValue contains: {actualType}, expecting a Tensor<T>");
    }

    OrtValue ortValue = OrtValue.CreateFromTensorObject(tensor, out TensorElementType elementType);
    try
    {
        if (elementType != elementMeta.ElementDataType)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                $"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}");
        }
    }
    catch (Exception)
    {
        // Validation failed (or reading the metadata threw): release the OrtValue
        // before propagating, since it will never reach the caller.
        ortValue.Dispose();
        throw;
    }

    return ortValue;
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|