// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Helps to feed a <see cref="NamedOnnxValue"/> as inference input.
/// It projects managed classes to OrtValues so they can be consumed
/// by the native onnxruntime library. If possible, it will avoid copying data.
/// The NamedOnnxValue can be a tensor, sequence or map.
/// For recursive structures, create nested NamedOnnxValue instances.
/// For example, a sequence instance would contain a list of NamedOnnxValue instances
/// that in turn may represent tensors or other ONNX values.
/// </summary>
internal class ManagedTypeProjection
{
    /// <summary>
    /// Dispatches the creation of the projection based on the value type
    /// of the supplied <paramref name="namedOnnxValue"/>.
    /// </summary>
    /// <param name="namedOnnxValue">managed value to project</param>
    /// <param name="metadata">
    /// metadata the value is checked against; when it describes an optional type,
    /// the metadata of the optional's element is used instead
    /// </param>
    /// <returns>OrtValue created according to the metadata</returns>
    /// <exception cref="OnnxRuntimeException">
    /// thrown when the value type does not match the (optional-adjusted) metadata,
    /// or when the value type is not a tensor, sequence or map
    /// </exception>
    internal static OrtValue CreateProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
    {
        // Use element meta to create types when the metadata describes an optional.
        NodeMetadata meta = metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL
            ? metadata.AsOptionalMetadata().ElementMeta
            : metadata;

        if (namedOnnxValue.ValueType != meta.OnnxValueType)
        {
            throw new OnnxRuntimeException(ErrorCode.RuntimeException,
                $"NamedOnnxValue: {namedOnnxValue.Name} has value type: {namedOnnxValue.ValueType}" +
                $" expected: {meta.OnnxValueType} after optional type adjustment");
        }

        switch (namedOnnxValue.ValueType)
        {
            case OnnxValueType.ONNX_TYPE_TENSOR:
                return CreateTensorProjection(namedOnnxValue, meta);
            case OnnxValueType.ONNX_TYPE_SEQUENCE:
                return CreateSequenceProjection(namedOnnxValue, meta);
            case OnnxValueType.ONNX_TYPE_MAP:
                return CreateMapProjection(namedOnnxValue, meta);
            default:
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "ManagedTypeProjection can only project tensors, sequences, maps and optional types");
        }
    }

    /// <summary>
    /// The function creates OrtValue objects for each element of the sequence
    /// and then creates an OrtValue for the whole sequence.
    /// </summary>
    /// <param name="namedOnnxValue">NamedOnnxValue containing a IEnumerable{NamedOnnxValue}</param>
    /// <param name="metadata">sequence metadata</param>
    /// <returns>OrtValue that represents a sequence</returns>
    /// <exception cref="OnnxRuntimeException">
    /// thrown when the value cannot be enumerated as NamedOnnxValue elements, or when
    /// an element's value type differs from the element type in the sequence metadata
    /// </exception>
    private static OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
    {
        var elementMeta = metadata.AsSequenceMetadata().ElementMeta;
        var expectedElementType = elementMeta.OnnxValueType;

        var elements = namedOnnxValue.AsEnumerable<NamedOnnxValue>() ??
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                $"NamedOnnxValue: {namedOnnxValue.Name} sequence does not contain NamedOnnxValue elements");

        // Pre-size the list when the element count is cheaply available.
        int capacity = elements is ICollection<NamedOnnxValue> collection ? collection.Count : 0;

        DisposableList<OrtValue> projectedElements = new(capacity);
        try
        {
            foreach (var element in elements)
            {
                if (expectedElementType != element.ValueType)
                {
                    throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                        $"NamedOnnxValue: {namedOnnxValue.Name} sequence element expected to be {expectedElementType}, received {element.ValueType}");
                }

                // Elements may themselves be tensors, sequences or maps.
                projectedElements.Add(CreateProjection(element, elementMeta));
            }

            return OrtValue.CreateSequence(ref projectedElements);
        }
        catch (Exception)
        {
            // The list is passed by ref to CreateSequence and may have been reset by it,
            // hence the null-conditional dispose of whatever is still referenced here.
            projectedElements?.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Creates map projection. Since we support only primitive types in maps
    /// we map two tensors (keys and values).
    /// </summary>
    /// <param name="node">NamedOnnxValue that supplies the dictionary keys and values as tensors</param>
    /// <param name="elementMeta">map metadata</param>
    /// <returns>OrtValue that represents a map</returns>
    /// <exception cref="OnnxRuntimeException">
    /// thrown when the map value type is not a tensor, or when the key/value element
    /// data types differ from the ones declared in the metadata
    /// </exception>
    private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta)
    {
        MapMetadata mapMeta = elementMeta.AsMapMetadata();
        Debug.Assert(mapMeta != null);

        // Maps currently support only primitive types expressed as two parallel tensors
        // and not nested Sequences or Maps.
        var valueMeta = mapMeta.ValueMetadata;
        if (valueMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                $"Node: {node.Name} onnxruntime only supports maps with primitive types values");
        }

        // Slot 0 holds the keys tensor, slot 1 holds the values tensor.
        Span<OrtValue> ortValues = new OrtValue[2];
        var disposableGuard = new DisposableArray<OrtValue>(ortValues);
        try
        {
            TensorBase keys = node.GetDictionaryKeys();
            ortValues[0] = OrtValue.CreateFromTensorObject(keys, out TensorElementType elementTypeKeys);
            if (elementTypeKeys != mapMeta.KeyDataType)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                    $"Map key data type supplied: {elementTypeKeys} metadata expected: {mapMeta.KeyDataType}");
            }

            TensorBase values = node.GetDictionaryValues();
            ortValues[1] = OrtValue.CreateFromTensorObject(values, out TensorElementType elementTypeValues);
            if (elementTypeValues != valueMeta.ElementDataType)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                    $"Map value data type supplied: {elementTypeValues} metadata expected: {valueMeta.ElementDataType}");
            }

            // Create Map OrtValue
            return OrtValue.CreateMap(ref ortValues[0], ref ortValues[1]);
        }
        catch (Exception)
        {
            // Release whatever was created before the failure.
            disposableGuard.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Creates an OrtValue from the tensor held by <paramref name="node"/> and verifies
    /// that its element type matches the metadata.
    /// Per the original author's note, this pins memory that is contained within DenseTensor.
    /// </summary>
    /// <param name="node">NodeOnnxValue containing DenseTensor</param>
    /// <param name="elementMeta">tensor metadata supplying the expected element data type</param>
    /// <returns>OrtValue created from the tensor object</returns>
    /// <exception cref="OnnxRuntimeException">
    /// thrown when the node value is null or not a tensor, or when the tensor element
    /// type differs from the one declared in the metadata
    /// </exception>
    private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
    {
        if (node.Value is not TensorBase tensor)
        {
            // node.Value may be null on this path; describe it without dereferencing so
            // the caller gets the intended OnnxRuntimeException, not a NullReferenceException.
            string actualType = node.Value?.GetType().ToString() ?? "null";
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                $"NamedOnnxValue contains: {actualType}, expecting a Tensor");
        }

        OrtValue ortValue = OrtValue.CreateFromTensorObject(tensor, out TensorElementType elementType);
        try
        {
            if (elementType != elementMeta.ElementDataType)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                    $"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}");
            }

            return ortValue;
        }
        catch (Exception)
        {
            // Do not leak the freshly created OrtValue when validation fails.
            ortValue.Dispose();
            throw;
        }
    }
}
}