mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add Float16 and BFloat16 support to C# API (#5775)
Add Float16 and BFloat16 support.
This commit is contained in:
parent
4d517c68a3
commit
2f35e65135
10 changed files with 452 additions and 178 deletions
|
|
@ -5,7 +5,6 @@ using Microsoft.ML.OnnxRuntime.Tensors;
|
|||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
|
|
@ -55,8 +54,8 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// tensors, sequences of tensors, sequences and maps
|
||||
/// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type
|
||||
/// The class must be disposed of.
|
||||
/// It disposes of _ortValueHolder that owns the underlying Ort output value or
|
||||
/// anything that the class that implements that interfaces needs to dispose.
|
||||
/// It disposes of _ortValueHolder that owns the underlying Ort output value and
|
||||
/// anything else that would need to be disposed by the instance of the class.
|
||||
/// Use factory method CreateFromOrtValue to obtain an instance of the class.
|
||||
/// </summary>
|
||||
public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable
|
||||
|
|
@ -64,6 +63,19 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
private IOrtValueOwner _ortValueHolder;
|
||||
private bool _disposed = false;
|
||||
|
||||
/// <summary>
|
||||
/// Ctor
|
||||
/// </summary>
|
||||
/// <param name="name">Name of the output value</param>
|
||||
/// <param name="value">Managed object created to represent output value, such as DenseTensor<T>
|
||||
/// List or Dictionary
|
||||
/// </param>
|
||||
/// <param name="onnxValueType">Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary()
|
||||
/// or AsEnumerable()</param>
|
||||
/// <param name="elementType">Tensor element type if value type is a Tensor</param>
|
||||
/// <param name="ortValueHolder">Object that holds native resources.
|
||||
/// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be
|
||||
/// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented.</param>
|
||||
private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder)
|
||||
: base(name, value)
|
||||
{
|
||||
|
|
@ -169,6 +181,12 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
case TensorElementType.Bool:
|
||||
result = DisposableNamedOnnxValueFromNativeTensor<bool>(name, ortValue);
|
||||
break;
|
||||
case TensorElementType.Float16:
|
||||
result = DisposableNamedOnnxValueFromNativeTensor<Float16>(name, ortValue);
|
||||
break;
|
||||
case TensorElementType.BFloat16:
|
||||
result = DisposableNamedOnnxValueFromNativeTensor<BFloat16>(name, ortValue);
|
||||
break;
|
||||
default:
|
||||
throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported");
|
||||
|
||||
|
|
|
|||
|
|
@ -83,60 +83,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
|
||||
{
|
||||
switch (elemType)
|
||||
TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
|
||||
if(result != null)
|
||||
{
|
||||
case TensorElementType.Float:
|
||||
type = typeof(float);
|
||||
width = sizeof(float);
|
||||
break;
|
||||
case TensorElementType.Double:
|
||||
type = typeof(double);
|
||||
width = sizeof(double);
|
||||
break;
|
||||
case TensorElementType.Int16:
|
||||
type = typeof(short);
|
||||
width = sizeof(short);
|
||||
break;
|
||||
case TensorElementType.UInt16:
|
||||
type = typeof(ushort);
|
||||
width = sizeof(ushort);
|
||||
break;
|
||||
case TensorElementType.Int32:
|
||||
type = typeof(int);
|
||||
width = sizeof(int);
|
||||
break;
|
||||
case TensorElementType.UInt32:
|
||||
type = typeof(uint);
|
||||
width = sizeof(uint);
|
||||
break;
|
||||
case TensorElementType.Int64:
|
||||
type = typeof(long);
|
||||
width = sizeof(long);
|
||||
break;
|
||||
case TensorElementType.UInt64:
|
||||
type = typeof(ulong);
|
||||
width = sizeof(ulong);
|
||||
break;
|
||||
case TensorElementType.UInt8:
|
||||
type = typeof(byte);
|
||||
width = sizeof(byte);
|
||||
break;
|
||||
case TensorElementType.Int8:
|
||||
type = typeof(sbyte);
|
||||
width = sizeof(sbyte);
|
||||
break;
|
||||
case TensorElementType.String:
|
||||
type = typeof(string);
|
||||
width = sizeof(byte);
|
||||
break;
|
||||
case TensorElementType.Bool:
|
||||
type = typeof(bool);
|
||||
width = sizeof(bool);
|
||||
break;
|
||||
default:
|
||||
type = null;
|
||||
width = 0;
|
||||
break;
|
||||
type = result.TensorType;
|
||||
width = result.TypeSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
type = null;
|
||||
width = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -208,6 +208,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
out memHandle, out dataBufferLength,
|
||||
out shape, out rank);
|
||||
break;
|
||||
case TensorElementType.Float16:
|
||||
PinAsTensor(value as Tensor<Float16>, typeSize,
|
||||
out memHandle, out dataBufferLength,
|
||||
out shape, out rank);
|
||||
break;
|
||||
case TensorElementType.BFloat16:
|
||||
PinAsTensor(value as Tensor<BFloat16>, typeSize,
|
||||
out memHandle, out dataBufferLength,
|
||||
out shape, out rank);
|
||||
break;
|
||||
default:
|
||||
throw new NotSupportedException("Element type: " + elType + " is not of a supported type");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,14 +11,13 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Text;
|
||||
using System;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
|
||||
// Making this assembly's internals visible to the internal Test assembly
|
||||
[assembly: InternalsVisibleTo("Microsoft.ML.OnnxRuntime.Tests," +
|
||||
|
|
@ -55,28 +54,168 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
|
|||
DataTypeMax = 17
|
||||
}
|
||||
|
||||
[StructLayout(LayoutKind.Sequential)]
|
||||
internal struct Float16
|
||||
/// <summary>
|
||||
/// This value type represents A Float16 value
|
||||
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
|
||||
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
|
||||
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
|
||||
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
|
||||
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
|
||||
/// </summary>
|
||||
public struct Float16
|
||||
{
|
||||
public ushort Value { get; private set; }
|
||||
public Float16(ushort val)
|
||||
public ushort value;
|
||||
/// <summary>
|
||||
/// Ctor
|
||||
/// </summary>
|
||||
/// <param name="v"></param>
|
||||
public Float16(ushort v)
|
||||
{
|
||||
Value = val;
|
||||
value = v;
|
||||
}
|
||||
}
|
||||
|
||||
[StructLayout(LayoutKind.Sequential)]
|
||||
internal struct BFloat16
|
||||
{
|
||||
public ushort Value { get; private set; }
|
||||
public BFloat16(ushort val)
|
||||
/// <summary>
|
||||
/// Converts to ushort
|
||||
/// </summary>
|
||||
/// <param name="f">instance of Float16</param>
|
||||
public static implicit operator ushort (Float16 f) { return f.value; }
|
||||
/// <summary>
|
||||
/// Converts a 16-bit unsigned integer to a Float16.
|
||||
/// </summary>
|
||||
/// <param name="value">A 16-bit unsigned integer.</param>
|
||||
/// <returns>A Float16 that represents the converted 16-bit unsigned integer.</returns>
|
||||
public static implicit operator Float16(ushort value) { return new Float16(value); }
|
||||
/// <summary>
|
||||
/// Compares values of two Float16 for binary equality
|
||||
/// </summary>
|
||||
/// <param name="lhs"></param>
|
||||
/// <param name="rhs"></param>
|
||||
/// <returns>result of value comparisons</returns>
|
||||
public static bool operator ==(Float16 lhs, Float16 rhs) { return lhs.value == rhs.value; }
|
||||
/// <summary>
|
||||
/// Compares values of two Float16 for binary inequality
|
||||
/// </summary>
|
||||
/// <param name="lhs"></param>
|
||||
/// <param name="rhs"></param>
|
||||
/// <returns>result of value comparisons</returns>
|
||||
public static bool operator !=(Float16 lhs, Float16 rhs) { return lhs.value != rhs.value; }
|
||||
/// <summary>
|
||||
/// Returns a value indicating whether this instance and other Float16 represent the same value.
|
||||
/// </summary>
|
||||
/// <param name="other">A Float16 object to compare to this instance.</param>
|
||||
/// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
|
||||
public bool Equals(Float16 other)
|
||||
{
|
||||
Value = val;
|
||||
return (other == this);
|
||||
}
|
||||
/// <summary>
|
||||
/// Returns a value indicating whether this instance and a specified System.Object
|
||||
/// represent the same type and value.
|
||||
/// </summary>
|
||||
/// <param name="obj">An System.Object.</param>
|
||||
/// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
|
||||
public override bool Equals(object obj)
|
||||
{
|
||||
bool result = false;
|
||||
if (obj is Float16)
|
||||
{
|
||||
Float16 fl16 = (Float16)obj;
|
||||
result = (fl16 == this);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
/// <summary>
|
||||
/// Returns the hash code for this instance.
|
||||
/// </summary>
|
||||
/// <returns>A 32-bit signed integer hash code.</returns>
|
||||
public override int GetHashCode()
|
||||
{
|
||||
return value.GetHashCode();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Helps typecasting. Holds primitive type information
|
||||
/// This value type represents A BFloat16 value
|
||||
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
|
||||
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
|
||||
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
|
||||
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
|
||||
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
|
||||
/// </summary>
|
||||
public struct BFloat16
{
    /// <summary>
    /// Raw 16-bit pattern held by this instance.
    /// </summary>
    public ushort value;

    /// <summary>
    /// Creates an instance that stores the given raw 16-bit pattern as is.
    /// </summary>
    /// <param name="v">raw bits to store</param>
    public BFloat16(ushort v)
    {
        this.value = v;
    }

    /// <summary>
    /// Exposes the stored raw bits as a ushort.
    /// </summary>
    /// <param name="bf">instance of BFloat16</param>
    public static implicit operator ushort(BFloat16 bf)
    {
        return bf.value;
    }

    /// <summary>
    /// Wraps a 16-bit unsigned integer into a BFloat16 without changing its bits.
    /// </summary>
    /// <param name="value">A 16-bit unsigned integer.</param>
    /// <returns>A BFloat16 holding the supplied bits.</returns>
    public static implicit operator BFloat16(ushort value)
    {
        return new BFloat16(value);
    }

    /// <summary>
    /// Binary equality: true when both operands store the same bits.
    /// </summary>
    /// <param name="lhs">left operand</param>
    /// <param name="rhs">right operand</param>
    /// <returns>true if the stored bits are identical</returns>
    public static bool operator ==(BFloat16 lhs, BFloat16 rhs)
    {
        return lhs.value.Equals(rhs.value);
    }

    /// <summary>
    /// Binary inequality: true when the operands store different bits.
    /// </summary>
    /// <param name="lhs">left operand</param>
    /// <param name="rhs">right operand</param>
    /// <returns>true if the stored bits differ</returns>
    public static bool operator !=(BFloat16 lhs, BFloat16 rhs)
    {
        return !lhs.value.Equals(rhs.value);
    }

    /// <summary>
    /// Tells whether another BFloat16 stores the same bits as this instance.
    /// </summary>
    /// <param name="other">A BFloat16 to compare to this instance.</param>
    /// <returns>true if other.value is equal to this instance's value; otherwise, false.</returns>
    public bool Equals(BFloat16 other)
    {
        return value == other.value;
    }

    /// <summary>
    /// Tells whether an arbitrary object is a BFloat16 that stores the same bits
    /// as this instance.
    /// </summary>
    /// <param name="obj">A System.Object, may be null.</param>
    /// <returns>true if obj is a BFloat16 and its value is equal to this instance's value; otherwise, false.</returns>
    public override bool Equals(object obj)
    {
        return obj is BFloat16 && Equals((BFloat16)obj);
    }

    /// <summary>
    /// Hash code of the stored raw bits.
    /// </summary>
    /// <returns>A 32-bit signed integer hash code.</returns>
    public override int GetHashCode()
    {
        return value.GetHashCode();
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Helps typecasting. Holds Tensor element type traits.
|
||||
/// </summary>
|
||||
public class TensorTypeInfo
|
||||
{
|
||||
|
|
@ -90,10 +229,33 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
/// Traits of a tensor element type: the managed type that represents an element,
/// the element size passed in by the creator, and whether the element is a string.
/// </summary>
public class TensorElementTypeInfo
{
    /// <summary>
    /// Captures the traits of one element type.
    /// </summary>
    /// <param name="type">managed type that represents a single element</param>
    /// <param name="typeSize">element size to report through TypeSize</param>
    public TensorElementTypeInfo(Type type, int typeSize)
    {
        // The string flag is derived from the type once, at construction.
        IsString = typeof(string) == type;
        TypeSize = typeSize;
        TensorType = type;
    }

    /// <summary>Managed type that represents a single element.</summary>
    public Type TensorType { get; private set; }

    /// <summary>Element size as supplied to the constructor.</summary>
    public int TypeSize { get; private set; }

    /// <summary>True when TensorType is System.String.</summary>
    public bool IsString { get; private set; }
}
|
||||
|
||||
/// <summary>
|
||||
/// This class is a base for all Tensors. It hosts maps with type traits.
|
||||
/// </summary>
|
||||
public class TensorBase
|
||||
{
|
||||
private static readonly Dictionary<Type, TensorTypeInfo> typeInfoMap =
|
||||
new Dictionary<Type, TensorTypeInfo>()
|
||||
private static readonly Dictionary<Type, TensorTypeInfo> typeInfoMap;
|
||||
|
||||
private static readonly Dictionary<TensorElementType, TensorElementTypeInfo> tensorElementTypeInfoMap;
|
||||
|
||||
static TensorBase () {
|
||||
typeInfoMap = new Dictionary<Type, TensorTypeInfo>()
|
||||
{
|
||||
{ typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) },
|
||||
{ typeof(byte), new TensorTypeInfo( TensorElementType.UInt8, sizeof(byte)) },
|
||||
|
|
@ -111,20 +273,57 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
|
|||
{ typeof(BFloat16), new TensorTypeInfo( TensorElementType.BFloat16, sizeof(ushort)) }
|
||||
};
|
||||
|
||||
tensorElementTypeInfoMap = new Dictionary<TensorElementType, TensorElementTypeInfo>();
|
||||
foreach(var info in typeInfoMap)
|
||||
{
|
||||
tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize));
|
||||
}
|
||||
}
|
||||
|
||||
private readonly Type _primitiveType;
|
||||
protected TensorBase(Type primitiveType)
|
||||
{
|
||||
// Should hold as we rely on this to pass arrays of these
|
||||
// types to native code
|
||||
unsafe
|
||||
{
|
||||
Debug.Assert(sizeof(ushort) == sizeof(Float16));
|
||||
Debug.Assert(sizeof(ushort) == sizeof(BFloat16));
|
||||
}
|
||||
_primitiveType = primitiveType;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Queries the map returns result or null
|
||||
/// Query TensorTypeInfo by one of the supported types
|
||||
/// </summary>
|
||||
/// <param name="type"></param>
|
||||
/// <returns>TensorTypeInfo or null if not supported</returns>
|
||||
public static TensorTypeInfo GetTypeInfo(Type type)
|
||||
{
|
||||
TensorTypeInfo result = null;
|
||||
typeInfoMap.TryGetValue(type, out result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Query TensorElementTypeInfo by enum
|
||||
/// </summary>
|
||||
/// <param name="elementType">type enum</param>
|
||||
/// <returns>instance of TensorElementTypeInfo or null if not found</returns>
|
||||
public static TensorElementTypeInfo GetElementTypeInfo(TensorElementType elementType)
|
||||
{
|
||||
TensorElementTypeInfo result = null;
|
||||
tensorElementTypeInfoMap.TryGetValue(elementType, out result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Query TensorTypeInfo using this Tensor type
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public TensorTypeInfo GetTypeInfo()
|
||||
{
|
||||
TensorTypeInfo result = null;
|
||||
typeInfoMap.TryGetValue(_primitiveType, out result);
|
||||
return result;
|
||||
return GetTypeInfo(_primitiveType);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -312,6 +511,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
|
|||
{
|
||||
return (T)(object)(ushort)(0);
|
||||
}
|
||||
else if (typeof(T) == typeof(Float16))
|
||||
{
|
||||
return (T)(object)(ushort)(0);
|
||||
}
|
||||
else if (typeof(T) == typeof(BFloat16))
|
||||
{
|
||||
return (T)(object)(ushort)(0);
|
||||
}
|
||||
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
|
|
@ -372,6 +579,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
|
|||
else if (typeof(T) == typeof(ushort))
|
||||
{
|
||||
return (T)(object)(ushort)(1);
|
||||
}
|
||||
else if(typeof(T) == typeof(Float16))
|
||||
{
|
||||
return (T)(object)(ushort)(15360);
|
||||
}
|
||||
else if (typeof(T) == typeof(BFloat16))
|
||||
{
|
||||
return (T)(object)(ushort)(16256);
|
||||
}
|
||||
|
||||
throw new NotSupportedException();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
using Microsoft.ML.OnnxRuntime.Tensors;
|
||||
using Microsoft.VisualBasic.CompilerServices;
|
||||
using Onnx;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
|
@ -676,15 +678,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
var skipModels = new Dictionary<string, string>() {
|
||||
{ "mxnet_arcface", "Model is an invalid ONNX model"},
|
||||
{ "tf_inception_v2", "TODO: Debug failing model, skipping for now" },
|
||||
{ "fp16_inception_v1", "16-bit float not supported type in C#." },
|
||||
{ "fp16_shufflenet", "16-bit float not supported type in C#." },
|
||||
{ "fp16_tiny_yolov2", "16-bit float not supported type in C#." },
|
||||
{ "fp16_coreml_FNS-Candy", "16-bit float not supported type in C#." },
|
||||
{ "test_mnist", "16-bit float not supported type in C#." },
|
||||
{ "fp16_test_shufflenet", "16-bit float not supported type in C#." },
|
||||
{ "fp16_coreml_LinearRegression_NYCTaxi", "16-bit float not supported type in C#." },
|
||||
{ "test_bidaf", "16-bit float not supported type in C#." },
|
||||
{ "fp16_test_tiny_yolov2", "16-bit float not supported type in C#." },
|
||||
{ "fp16_tiny_yolov2", "Tolerance level for float16 is not known. We now support fp16." },
|
||||
{ "test_bidaf", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile." },
|
||||
{ "test_mnist", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile" },
|
||||
{ "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" },
|
||||
{ "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" },
|
||||
{ "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" },
|
||||
|
|
@ -826,7 +822,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string testDataDirNamePattern = "test_data*";
|
||||
if (opset == "opset9" && modelName == "LSTM_Seq_lens_unpacked")
|
||||
{
|
||||
testDataDirNamePattern = "seq_lens*"; // discrepency in data directory
|
||||
testDataDirNamePattern = "seq_lens*"; // discrepancy in data directory
|
||||
}
|
||||
foreach (var testDataDir in modelDir.EnumerateDirectories(testDataDirNamePattern))
|
||||
{
|
||||
|
|
@ -898,6 +894,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{
|
||||
Assert.Equal(result.AsTensor<bool>(), outputValue.AsTensor<bool>(), new ExactComparer<bool>());
|
||||
}
|
||||
else if (outputMeta.ElementType == typeof(Float16))
|
||||
{
|
||||
Assert.Equal(result.AsTensor<Float16>(), outputValue.AsTensor<Float16>(), new Float16Comparer { tolerance = 2 });
|
||||
}
|
||||
else if (outputMeta.ElementType == typeof(BFloat16))
|
||||
{
|
||||
Assert.Equal(result.AsTensor<BFloat16>(), outputValue.AsTensor<BFloat16>(), new BFloat16Comparer { tolerance = 2 });
|
||||
}
|
||||
else
|
||||
{
|
||||
Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType));
|
||||
|
|
@ -1520,20 +1524,47 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact(Skip = "FLOAT16 not available in C#")]
|
||||
[Fact]
|
||||
private void TestModelInputFLOAT16()
|
||||
{
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.pb");
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.onnx");
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var tensorIn = new DenseTensor<float>(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 });
|
||||
var tensorIn = new DenseTensor<Float16>(
|
||||
new Float16[] { 15360, 16384, 16896, 17408, 17664 }, new int[] { 1, 5 });
|
||||
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
|
||||
container.Add(nov);
|
||||
using (var res = session.Run(container))
|
||||
{
|
||||
var tensorOut = res.First().AsTensor<float>();
|
||||
var valueOut = res.First();
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType);
|
||||
Assert.Equal(Tensors.TensorElementType.Float16, valueOut.ElementType);
|
||||
var tensorOut = res.First().AsTensor<Float16>();
|
||||
Assert.True(tensorOut.SequenceEqual(tensorIn));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
private void TestModelInputBFLOAT16()
|
||||
{
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_BFLOAT16.onnx");
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var tensorIn = new DenseTensor<BFloat16>(
|
||||
new BFloat16[] { 16256, 16384, 16448, 16512, 16544 }, new int[] { 1, 5 });
|
||||
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
|
||||
container.Add(nov);
|
||||
using (var res = session.Run(container))
|
||||
{
|
||||
var valueOut = res.First();
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType);
|
||||
Assert.Equal(Tensors.TensorElementType.BFloat16, valueOut.ElementType);
|
||||
var tensorOut = res.First().AsTensor<BFloat16>();
|
||||
Assert.True(tensorOut.SequenceEqual(tensorIn));
|
||||
}
|
||||
}
|
||||
|
|
@ -1580,7 +1611,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
var outNode0 = outputs.ElementAtOrDefault(0);
|
||||
Assert.Equal("label", outNode0.Name);
|
||||
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outNode0.ValueType);
|
||||
Assert.Equal(TensorElementType.Int64, (TensorElementType)outNode0.ElementType);
|
||||
Assert.Equal(Tensors.TensorElementType.Int64, outNode0.ElementType);
|
||||
|
||||
// try-cast as a tensor
|
||||
var outLabelTensor = outNode0.AsTensor<long>();
|
||||
|
|
@ -2151,86 +2182,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
return tensorData.ToArray();
|
||||
}
|
||||
|
||||
|
||||
private enum TensorElementType
|
||||
private static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width)
|
||||
{
|
||||
Float = 1,
|
||||
UInt8 = 2,
|
||||
Int8 = 3,
|
||||
UInt16 = 4,
|
||||
Int16 = 5,
|
||||
Int32 = 6,
|
||||
Int64 = 7,
|
||||
String = 8,
|
||||
Bool = 9,
|
||||
Float16 = 10,
|
||||
Double = 11,
|
||||
UInt32 = 12,
|
||||
UInt64 = 13,
|
||||
Complex64 = 14,
|
||||
Complex128 = 15,
|
||||
BFloat16 = 16,
|
||||
DataTypeMax = 17
|
||||
}
|
||||
|
||||
private static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
|
||||
{
|
||||
switch (elemType)
|
||||
TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
|
||||
if (result != null)
|
||||
{
|
||||
case TensorElementType.Float:
|
||||
type = typeof(float);
|
||||
width = sizeof(float);
|
||||
break;
|
||||
case TensorElementType.Double:
|
||||
type = typeof(double);
|
||||
width = sizeof(double);
|
||||
break;
|
||||
case TensorElementType.Int16:
|
||||
type = typeof(short);
|
||||
width = sizeof(short);
|
||||
break;
|
||||
case TensorElementType.UInt16:
|
||||
type = typeof(ushort);
|
||||
width = sizeof(ushort);
|
||||
break;
|
||||
case TensorElementType.Int32:
|
||||
type = typeof(int);
|
||||
width = sizeof(int);
|
||||
break;
|
||||
case TensorElementType.UInt32:
|
||||
type = typeof(uint);
|
||||
width = sizeof(uint);
|
||||
break;
|
||||
case TensorElementType.Int64:
|
||||
type = typeof(long);
|
||||
width = sizeof(long);
|
||||
break;
|
||||
case TensorElementType.UInt64:
|
||||
type = typeof(ulong);
|
||||
width = sizeof(ulong);
|
||||
break;
|
||||
case TensorElementType.UInt8:
|
||||
type = typeof(byte);
|
||||
width = sizeof(byte);
|
||||
break;
|
||||
case TensorElementType.Int8:
|
||||
type = typeof(sbyte);
|
||||
width = sizeof(sbyte);
|
||||
break;
|
||||
case TensorElementType.String:
|
||||
type = typeof(byte);
|
||||
width = sizeof(byte);
|
||||
break;
|
||||
case TensorElementType.Bool:
|
||||
type = typeof(bool);
|
||||
width = sizeof(bool);
|
||||
break;
|
||||
default:
|
||||
type = null;
|
||||
width = 0;
|
||||
break;
|
||||
type = result.TensorType;
|
||||
width = result.TypeSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
type = null;
|
||||
width = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary<string, NodeMetadata> nodeMetaDict)
|
||||
{
|
||||
//Set buffer size to 4MB
|
||||
|
|
@ -2243,7 +2209,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
|
||||
Type tensorElemType = null;
|
||||
int width = 0;
|
||||
GetTypeAndWidth((TensorElementType)tensor.DataType, out tensorElemType, out width);
|
||||
GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width);
|
||||
var intDims = new int[tensor.Dims.Count];
|
||||
for (int i = 0; i < tensor.Dims.Count; i++)
|
||||
{
|
||||
|
|
@ -2251,7 +2217,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
|
||||
NodeMetadata nodeMeta = null;
|
||||
string nodeName = "";
|
||||
string nodeName = string.Empty;
|
||||
if (nodeMetaDict.Count == 1)
|
||||
{
|
||||
nodeMeta = nodeMetaDict.Values.First();
|
||||
|
|
@ -2259,7 +2225,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
else if (nodeMetaDict.Count > 1)
|
||||
{
|
||||
if (tensor.Name != "")
|
||||
if (tensor.Name.Length > 0)
|
||||
{
|
||||
nodeMeta = nodeMetaDict[tensor.Name];
|
||||
nodeName = tensor.Name;
|
||||
|
|
@ -2353,6 +2319,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{
|
||||
return CreateNamedOnnxValueFromRawData<bool>(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims);
|
||||
}
|
||||
else if (nodeMeta.ElementType == typeof(Float16))
|
||||
{
|
||||
return CreateNamedOnnxValueFromRawData<Float16>(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
|
||||
}
|
||||
else if (nodeMeta.ElementType == typeof(BFloat16))
|
||||
{
|
||||
return CreateNamedOnnxValueFromRawData<BFloat16>(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
|
||||
}
|
||||
else
|
||||
{
|
||||
//TODO: Add support for remaining types
|
||||
|
|
@ -2363,9 +2337,24 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
|
||||
static NamedOnnxValue CreateNamedOnnxValueFromRawData<T>(string name, byte[] rawData, int elemWidth, int[] dimensions)
|
||||
{
|
||||
T[] floatArr = new T[rawData.Length / elemWidth];
|
||||
Buffer.BlockCopy(rawData, 0, floatArr, 0, rawData.Length);
|
||||
var dt = new DenseTensor<T>(floatArr, dimensions);
|
||||
T[] typedArr = new T[rawData.Length / elemWidth];
|
||||
var typeOf = typeof(T);
|
||||
if(typeOf == typeof(Float16) || typeOf == typeof(BFloat16))
|
||||
{
|
||||
using (var memSrcHandle = new Memory<byte>(rawData).Pin())
|
||||
using (var memDstHandle = new Memory<T>(typedArr).Pin())
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
Buffer.MemoryCopy(memSrcHandle.Pointer, memDstHandle.Pointer, typedArr.Length * elemWidth, rawData.Length);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Buffer.BlockCopy(rawData, 0, typedArr, 0, rawData.Length);
|
||||
}
|
||||
var dt = new DenseTensor<T>(typedArr, dimensions);
|
||||
return NamedOnnxValue.CreateFromTensor<T>(name, dt);
|
||||
}
|
||||
|
||||
|
|
@ -2430,7 +2419,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
public int GetHashCode(float x)
|
||||
{
|
||||
return 0;
|
||||
return x.GetHashCode();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2442,7 +2431,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
public int GetHashCode(T x)
|
||||
{
|
||||
return 0;
|
||||
return x.GetHashCode();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
/// Compares Float16 values by the distance between their raw 16-bit patterns.
/// Two values are treated as equal when the patterns differ by at most
/// <see cref="tolerance"/>.
/// </summary>
internal class Float16Comparer : IEqualityComparer<Float16>
{
    /// <summary>
    /// Largest absolute difference between the raw bit patterns that still counts as equal.
    /// </summary>
    public ushort tolerance;

    /// <summary>
    /// Tells whether two values are within the tolerance of each other.
    /// </summary>
    /// <param name="x">first value</param>
    /// <param name="y">second value</param>
    /// <returns>true if the raw patterns differ by no more than tolerance</returns>
    public bool Equals(Float16 x, Float16 y)
    {
        // Float16 converts implicitly to ushort; ushort subtraction is carried out
        // in int, so the difference cannot wrap around.
        int difference = (ushort)x - (ushort)y;

        // The bound is the tolerance alone. Adding y's bit pattern to it (as was done
        // before) made the allowed difference grow with y and accepted nearly any pair.
        return Math.Abs(difference) <= tolerance;
    }

    /// <summary>
    /// Returns the hash code of the value itself.
    /// </summary>
    /// <param name="x">value to hash</param>
    /// <returns>hash code reported by the value</returns>
    public int GetHashCode(Float16 x)
    {
        // NOTE(review): the hash comes from the value itself, so two values that are
        // equal only within tolerance may hash differently - confirm this comparer is
        // not used with hash-based collections.
        return x.GetHashCode();
    }
}
|
||||
|
||||
/// <summary>
/// Compares BFloat16 values by the distance between their raw 16-bit patterns.
/// Two values are treated as equal when the patterns differ by at most
/// <see cref="tolerance"/>.
/// </summary>
internal class BFloat16Comparer : IEqualityComparer<BFloat16>
{
    /// <summary>
    /// Largest absolute difference between the raw bit patterns that still counts as equal.
    /// </summary>
    public ushort tolerance;

    /// <summary>
    /// Tells whether two values are within the tolerance of each other.
    /// </summary>
    /// <param name="x">first value</param>
    /// <param name="y">second value</param>
    /// <returns>true if the raw patterns differ by no more than tolerance</returns>
    public bool Equals(BFloat16 x, BFloat16 y)
    {
        // BFloat16 converts implicitly to ushort; ushort subtraction is carried out
        // in int, so the difference cannot wrap around.
        int difference = (ushort)x - (ushort)y;

        // The bound is the tolerance alone. Adding y's bit pattern to it (as was done
        // before) made the allowed difference grow with y and accepted nearly any pair.
        return Math.Abs(difference) <= tolerance;
    }

    /// <summary>
    /// Returns the hash code of the value itself.
    /// </summary>
    /// <param name="x">value to hash</param>
    /// <returns>hash code reported by the value</returns>
    public int GetHashCode(BFloat16 x)
    {
        // NOTE(review): the hash comes from the value itself, so two values that are
        // equal only within tolerance may hash differently - confirm this comparer is
        // not used with hash-based collections.
        return x.GetHashCode();
    }
}
|
||||
|
||||
|
|
|
|||
27
csharp/testdata/test_input_BFLOAT16.py
vendored
Normal file
27
csharp/testdata/test_input_BFLOAT16.py
vendored
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx.helper import make_opsetid
|
||||
from onnx import AttributeProto, TensorProto, GraphProto
|
||||
|
||||
# Builds a minimal model with one BFLOAT16 input and one BFLOAT16 output,
# both of shape [1, 5], and saves it as test_types_BFLOAT16.onnx.
input_info = helper.make_tensor_value_info('input', TensorProto.BFLOAT16, [1, 5])
output_info = helper.make_tensor_value_info('output', TensorProto.BFLOAT16, [1, 5])

# Create the single node (NodeProto) of the graph: an Identity operator wired
# from 'input' to 'output'.
node_def = helper.make_node(
    'Identity',  # operator type
    ['input'],   # inputs
    ['output']   # outputs
)

# The graph name matches the name of the model file written below.
graph_def = helper.make_graph(nodes=[node_def], name='test_types_BFLOAT16',
                              inputs=[input_info], outputs=[output_info])

model_def = helper.make_model(graph_def, producer_name='AIInfra',
                              opset_imports=[make_opsetid('', 13)])

final_model = onnx.utils.polish_model(model_def)
onnx.save(final_model, 'test_types_BFLOAT16.onnx')
|
||||
|
||||
30
csharp/testdata/test_input_FLOAT16.py
vendored
Normal file
30
csharp/testdata/test_input_FLOAT16.py
vendored
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx.helper import make_opsetid
|
||||
from onnx import AttributeProto, TensorProto, GraphProto
|
||||
|
||||
# Builds a minimal model with one FLOAT16 input and one FLOAT16 output,
# both of shape [1, 5], and saves it as test_types_FLOAT16.onnx.
model_input = helper.make_tensor_value_info('input', TensorProto.FLOAT16, [1, 5])
model_output = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 5])

# The only node is a Slice whose range (starts [0, 0], ends [1, 5] on axes 0 and 1)
# spans the full [1, 5] shape declared above.
slice_node = helper.make_node(
    'Slice',
    inputs=['input'],
    outputs=['output'],
    axes=[0, 1],
    ends=[1, 5],
    starts=[0, 0],
)

graph = helper.make_graph(
    [slice_node],
    'test_input_FLOAT16',
    [model_input],
    [model_output],
)

model = helper.make_model(
    graph,
    producer_name='AIInfra',
    opset_imports=[make_opsetid('', 7)],
)

# Run the model through onnx's polish step before writing it to disk.
polished = onnx.utils.polish_model(model)
onnx.save(polished, 'test_types_FLOAT16.onnx')
|
||||
|
||||
BIN
csharp/testdata/test_types_BFLOAT16.onnx
vendored
Normal file
BIN
csharp/testdata/test_types_BFLOAT16.onnx
vendored
Normal file
Binary file not shown.
BIN
csharp/testdata/test_types_FLOAT16.onnx
vendored
Normal file
BIN
csharp/testdata/test_types_FLOAT16.onnx
vendored
Normal file
Binary file not shown.
BIN
csharp/testdata/test_types_FLOAT16.pb
vendored
BIN
csharp/testdata/test_types_FLOAT16.pb
vendored
Binary file not shown.
Loading…
Reference in a new issue