Add Float16 and BFloat16 support to C# API (#5775)

Add Float16 and BFloat16 support.
This commit is contained in:
Dmitri Smirnov 2020-11-12 17:57:08 -08:00 committed by GitHub
parent 4d517c68a3
commit 2f35e65135
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 452 additions and 178 deletions

View file

@ -5,7 +5,6 @@ using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
namespace Microsoft.ML.OnnxRuntime
{
@ -55,8 +54,8 @@ namespace Microsoft.ML.OnnxRuntime
/// tensors, sequences of tensors, sequences and maps
/// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type
/// The class must be disposed of.
/// It disposes of _ortValueHolder that owns the underlying Ort output value or
/// anything that the class that implements that interfaces needs to dispose.
/// It disposes of _ortValueHolder that owns the underlying Ort output value and
/// anything else that would need to be disposed by the instance of the class.
/// Use factory method CreateFromOrtValue to obtain an instance of the class.
/// </summary>
public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable
@ -64,6 +63,19 @@ namespace Microsoft.ML.OnnxRuntime
private IOrtValueOwner _ortValueHolder;
private bool _disposed = false;
/// <summary>
/// Ctor
/// </summary>
/// <param name="name">Name of the output value</param>
/// <param name="value">Managed object created to represent output value, such as DenseTensor<T>
/// List or Dictionary
/// </param>
/// <param name="onnxValueType">Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary()
/// or AsEnumerable()</param>
/// <param name="elementType">Tensor element type if value type is a Tensor</param>
/// <param name="ortValueHolder">Object that holds native resources.
/// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be
/// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented.</param>
private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder)
: base(name, value)
{
@ -169,6 +181,12 @@ namespace Microsoft.ML.OnnxRuntime
case TensorElementType.Bool:
result = DisposableNamedOnnxValueFromNativeTensor<bool>(name, ortValue);
break;
case TensorElementType.Float16:
result = DisposableNamedOnnxValueFromNativeTensor<Float16>(name, ortValue);
break;
case TensorElementType.BFloat16:
result = DisposableNamedOnnxValueFromNativeTensor<BFloat16>(name, ortValue);
break;
default:
throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported");

View file

@ -83,60 +83,16 @@ namespace Microsoft.ML.OnnxRuntime
{
public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
{
switch (elemType)
TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
if(result != null)
{
case TensorElementType.Float:
type = typeof(float);
width = sizeof(float);
break;
case TensorElementType.Double:
type = typeof(double);
width = sizeof(double);
break;
case TensorElementType.Int16:
type = typeof(short);
width = sizeof(short);
break;
case TensorElementType.UInt16:
type = typeof(ushort);
width = sizeof(ushort);
break;
case TensorElementType.Int32:
type = typeof(int);
width = sizeof(int);
break;
case TensorElementType.UInt32:
type = typeof(uint);
width = sizeof(uint);
break;
case TensorElementType.Int64:
type = typeof(long);
width = sizeof(long);
break;
case TensorElementType.UInt64:
type = typeof(ulong);
width = sizeof(ulong);
break;
case TensorElementType.UInt8:
type = typeof(byte);
width = sizeof(byte);
break;
case TensorElementType.Int8:
type = typeof(sbyte);
width = sizeof(sbyte);
break;
case TensorElementType.String:
type = typeof(string);
width = sizeof(byte);
break;
case TensorElementType.Bool:
type = typeof(bool);
width = sizeof(bool);
break;
default:
type = null;
width = 0;
break;
type = result.TensorType;
width = result.TypeSize;
}
else
{
type = null;
width = 0;
}
}
}

View file

@ -208,6 +208,16 @@ namespace Microsoft.ML.OnnxRuntime
out memHandle, out dataBufferLength,
out shape, out rank);
break;
case TensorElementType.Float16:
PinAsTensor(value as Tensor<Float16>, typeSize,
out memHandle, out dataBufferLength,
out shape, out rank);
break;
case TensorElementType.BFloat16:
PinAsTensor(value as Tensor<BFloat16>, typeSize,
out memHandle, out dataBufferLength,
out shape, out rank);
break;
default:
throw new NotSupportedException("Element type: " + elType + " is not of a supported type");
}

View file

@ -11,14 +11,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Reflection;
using System.Text;
// Making this assembly's internals visible to the internal Test assembly
[assembly: InternalsVisibleTo("Microsoft.ML.OnnxRuntime.Tests," +
@ -55,28 +54,168 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
DataTypeMax = 17
}
[StructLayout(LayoutKind.Sequential)]
internal struct Float16
/// <summary>
/// This value type represents A Float16 value
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
/// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
/// </summary>
public struct Float16
{
public ushort Value { get; private set; }
public Float16(ushort val)
public ushort value;
/// <summary>
/// Ctor
/// </summary>
/// <param name="v"></param>
public Float16(ushort v)
{
Value = val;
value = v;
}
}
[StructLayout(LayoutKind.Sequential)]
internal struct BFloat16
{
public ushort Value { get; private set; }
public BFloat16(ushort val)
/// <summary>
/// Converts to ushort
/// </summary>
/// <param name="f">instance of Float16</param>
public static implicit operator ushort (Float16 f) { return f.value; }
/// <summary>
/// Converts a 16-bit unsigned integer to a Float16.
/// </summary>
/// <param name="value">A 16-bit unsigned integer.</param>
/// <returns>A Float16 that represents the converted 16-bit unsigned integer.</returns>
public static implicit operator Float16(ushort value) { return new Float16(value); }
/// <summary>
/// Compares values of two Float16 for binary equality
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>result of value comparisons</returns>
public static bool operator ==(Float16 lhs, Float16 rhs) { return lhs.value == rhs.value; }
/// <summary>
/// Compares values of two Float16 for binary inequality
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>result of value comparisons</returns>
public static bool operator !=(Float16 lhs, Float16 rhs) { return lhs.value != rhs.value; }
/// <summary>
/// Returns a value indicating whether this instance and other Float16 represent the same value.
/// </summary>
/// <param name="other">A Float16 object to compare to this instance.</param>
/// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
public bool Equals(Float16 other)
{
Value = val;
return (other == this);
}
/// <summary>
/// Returns a value indicating whether this instance and a specified System.Object
/// represent the same type and value.
/// </summary>
/// <param name="obj">An System.Object.</param>
/// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
public override bool Equals(object obj)
{
bool result = false;
if (obj is Float16)
{
Float16 fl16 = (Float16)obj;
result = (fl16 == this);
}
return result;
}
/// <summary>
/// Returns the hash code for this instance.
/// </summary>
/// <returns>A 32-bit signed integer hash code.</returns>
public override int GetHashCode()
{
return value.GetHashCode();
}
}
/// <summary>
/// Helps typecasting. Holds primitive type information
/// This value type represents A BFloat16 value
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
/// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
/// </summary>
public struct BFloat16
{
public ushort value;
/// <summary>
/// Ctor
/// </summary>
/// <param name="v"></param>
public BFloat16(ushort v)
{
value = v;
}
/// <summary>
/// Converts to ushort
/// </summary>
/// <param name="bf">instance of BFloat16</param>
public static implicit operator ushort(BFloat16 bf) { return bf.value; }
/// <summary>
/// Converts a 16-bit unsigned integer to a BFloat16.
/// </summary>
/// <param name="value">A 16-bit unsigned integer.</param>
/// <returns>A BFloat16 that represents the converted 16-bit unsigned integer.</returns>
public static implicit operator BFloat16(ushort value) { return new BFloat16(value); }
/// <summary>
/// Compares values of two BFloat16 for binary equality
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>result of value comparisons</returns>
public static bool operator ==(BFloat16 lhs, BFloat16 rhs) { return lhs.value == rhs.value; }
/// <summary>
/// Compares values of two BFloat16 for binary inequality
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>result of value comparisons</returns>
public static bool operator !=(BFloat16 lhs, BFloat16 rhs) { return lhs.value != rhs.value; }
/// <summary>
/// Returns a value indicating whether this instance and other BFloat16 represent the same value.
/// </summary>
/// <param name="other">A BFloat16 object to compare to this instance.</param>
/// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
public bool Equals(BFloat16 other)
{
return (other == this);
}
/// <summary>
/// Returns a value indicating whether this instance and a specified System.Object
/// represent the same type and value.
/// </summary>
/// <param name="obj">An System.Object.</param>
/// <returns>true if obj is BFloat16 its value is equal to this instance; otherwise, false.</returns>
public override bool Equals(object obj)
{
bool result = false;
if (obj is BFloat16)
{
BFloat16 bfl16 = (BFloat16)obj;
result = (bfl16 == this);
}
return result;
}
/// <summary>
/// Returns the hash code for this instance.
/// </summary>
/// <returns>A 32-bit signed integer hash code.</returns>
public override int GetHashCode()
{
return value.GetHashCode();
}
}
/// <summary>
/// Helps typecasting. Holds Tensor element type traits.
/// </summary>
public class TensorTypeInfo
{
@ -90,10 +229,33 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
}
}
/// <summary>
/// Holds TensorElement traits
/// </summary>
public class TensorElementTypeInfo
{
public Type TensorType { get; private set; }
public int TypeSize { get; private set; }
public bool IsString { get; private set; }
public TensorElementTypeInfo(Type type, int typeSize)
{
TensorType = type;
TypeSize = typeSize;
IsString = type == typeof(string);
}
}
/// <summary>
/// This class is a base for all Tensors. It hosts maps with type traits.
/// </summary>
public class TensorBase
{
private static readonly Dictionary<Type, TensorTypeInfo> typeInfoMap =
new Dictionary<Type, TensorTypeInfo>()
private static readonly Dictionary<Type, TensorTypeInfo> typeInfoMap;
private static readonly Dictionary<TensorElementType, TensorElementTypeInfo> tensorElementTypeInfoMap;
static TensorBase () {
typeInfoMap = new Dictionary<Type, TensorTypeInfo>()
{
{ typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) },
{ typeof(byte), new TensorTypeInfo( TensorElementType.UInt8, sizeof(byte)) },
@ -111,20 +273,57 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
{ typeof(BFloat16), new TensorTypeInfo( TensorElementType.BFloat16, sizeof(ushort)) }
};
tensorElementTypeInfoMap = new Dictionary<TensorElementType, TensorElementTypeInfo>();
foreach(var info in typeInfoMap)
{
tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize));
}
}
private readonly Type _primitiveType;
protected TensorBase(Type primitiveType)
{
// Should hold as we rely on this to pass arrays of these
// types to native code
unsafe
{
Debug.Assert(sizeof(ushort) == sizeof(Float16));
Debug.Assert(sizeof(ushort) == sizeof(BFloat16));
}
_primitiveType = primitiveType;
}
/// <summary>
/// Queries the map returns result or null
/// Query TensorTypeInfo by one of the supported types
/// </summary>
/// <param name="type"></param>
/// <returns>TensorTypeInfo or null if not supported</returns>
public static TensorTypeInfo GetTypeInfo(Type type)
{
TensorTypeInfo result = null;
typeInfoMap.TryGetValue(type, out result);
return result;
}
/// <summary>
/// Query TensorElementTypeInfo by enum
/// </summary>
/// <param name="elementType">type enum</param>
/// <returns>instance of TensorElementTypeInfo or null if not found</returns>
public static TensorElementTypeInfo GetElementTypeInfo(TensorElementType elementType)
{
TensorElementTypeInfo result = null;
tensorElementTypeInfoMap.TryGetValue(elementType, out result);
return result;
}
/// <summary>
/// Query TensorTypeInfo using this Tensor type
/// </summary>
/// <returns></returns>
public TensorTypeInfo GetTypeInfo()
{
TensorTypeInfo result = null;
typeInfoMap.TryGetValue(_primitiveType, out result);
return result;
return GetTypeInfo(_primitiveType);
}
}
@ -312,6 +511,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
{
return (T)(object)(ushort)(0);
}
else if (typeof(T) == typeof(Float16))
{
return (T)(object)(ushort)(0);
}
else if (typeof(T) == typeof(BFloat16))
{
return (T)(object)(ushort)(0);
}
throw new NotSupportedException();
}
@ -372,6 +579,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
else if (typeof(T) == typeof(ushort))
{
return (T)(object)(ushort)(1);
}
else if(typeof(T) == typeof(Float16))
{
return (T)(object)(ushort)(15360);
}
else if (typeof(T) == typeof(BFloat16))
{
return (T)(object)(ushort)(16256);
}
throw new NotSupportedException();

View file

@ -2,6 +2,8 @@
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.VisualBasic.CompilerServices;
using Onnx;
using System;
using System.Collections.Generic;
using System.IO;
@ -676,15 +678,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var skipModels = new Dictionary<string, string>() {
{ "mxnet_arcface", "Model is an invalid ONNX model"},
{ "tf_inception_v2", "TODO: Debug failing model, skipping for now" },
{ "fp16_inception_v1", "16-bit float not supported type in C#." },
{ "fp16_shufflenet", "16-bit float not supported type in C#." },
{ "fp16_tiny_yolov2", "16-bit float not supported type in C#." },
{ "fp16_coreml_FNS-Candy", "16-bit float not supported type in C#." },
{ "test_mnist", "16-bit float not supported type in C#." },
{ "fp16_test_shufflenet", "16-bit float not supported type in C#." },
{ "fp16_coreml_LinearRegression_NYCTaxi", "16-bit float not supported type in C#." },
{ "test_bidaf", "16-bit float not supported type in C#." },
{ "fp16_test_tiny_yolov2", "16-bit float not supported type in C#." },
{ "fp16_tiny_yolov2", "Tolerance level for float16 is not known. We now support fp16." },
{ "test_bidaf", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile." },
{ "test_mnist", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile" },
{ "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" },
{ "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" },
{ "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" },
@ -826,7 +822,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string testDataDirNamePattern = "test_data*";
if (opset == "opset9" && modelName == "LSTM_Seq_lens_unpacked")
{
testDataDirNamePattern = "seq_lens*"; // discrepency in data directory
testDataDirNamePattern = "seq_lens*"; // discrepancy in data directory
}
foreach (var testDataDir in modelDir.EnumerateDirectories(testDataDirNamePattern))
{
@ -898,6 +894,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
Assert.Equal(result.AsTensor<bool>(), outputValue.AsTensor<bool>(), new ExactComparer<bool>());
}
else if (outputMeta.ElementType == typeof(Float16))
{
Assert.Equal(result.AsTensor<Float16>(), outputValue.AsTensor<Float16>(), new Float16Comparer { tolerance = 2 });
}
else if (outputMeta.ElementType == typeof(BFloat16))
{
Assert.Equal(result.AsTensor<BFloat16>(), outputValue.AsTensor<BFloat16>(), new BFloat16Comparer { tolerance = 2 });
}
else
{
Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType));
@ -1520,20 +1524,47 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
[Fact(Skip = "FLOAT16 not available in C#")]
[Fact]
private void TestModelInputFLOAT16()
{
// model takes 1x5 input of fixed type, echoes back
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.pb");
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.onnx");
using (var session = new InferenceSession(modelPath))
{
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<float>(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 });
var tensorIn = new DenseTensor<Float16>(
new Float16[] { 15360, 16384, 16896, 17408, 17664 }, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
using (var res = session.Run(container))
{
var tensorOut = res.First().AsTensor<float>();
var valueOut = res.First();
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType);
Assert.Equal(Tensors.TensorElementType.Float16, valueOut.ElementType);
var tensorOut = res.First().AsTensor<Float16>();
Assert.True(tensorOut.SequenceEqual(tensorIn));
}
}
}
[Fact]
private void TestModelInputBFLOAT16()
{
// model takes 1x5 input of fixed type, echoes back
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_BFLOAT16.onnx");
using (var session = new InferenceSession(modelPath))
{
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<BFloat16>(
new BFloat16[] { 16256, 16384, 16448, 16512, 16544 }, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
using (var res = session.Run(container))
{
var valueOut = res.First();
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType);
Assert.Equal(Tensors.TensorElementType.BFloat16, valueOut.ElementType);
var tensorOut = res.First().AsTensor<BFloat16>();
Assert.True(tensorOut.SequenceEqual(tensorIn));
}
}
@ -1580,7 +1611,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var outNode0 = outputs.ElementAtOrDefault(0);
Assert.Equal("label", outNode0.Name);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outNode0.ValueType);
Assert.Equal(TensorElementType.Int64, (TensorElementType)outNode0.ElementType);
Assert.Equal(Tensors.TensorElementType.Int64, outNode0.ElementType);
// try-cast as a tensor
var outLabelTensor = outNode0.AsTensor<long>();
@ -2151,86 +2182,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
return tensorData.ToArray();
}
private enum TensorElementType
private static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width)
{
Float = 1,
UInt8 = 2,
Int8 = 3,
UInt16 = 4,
Int16 = 5,
Int32 = 6,
Int64 = 7,
String = 8,
Bool = 9,
Float16 = 10,
Double = 11,
UInt32 = 12,
UInt64 = 13,
Complex64 = 14,
Complex128 = 15,
BFloat16 = 16,
DataTypeMax = 17
}
private static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
{
switch (elemType)
TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
if (result != null)
{
case TensorElementType.Float:
type = typeof(float);
width = sizeof(float);
break;
case TensorElementType.Double:
type = typeof(double);
width = sizeof(double);
break;
case TensorElementType.Int16:
type = typeof(short);
width = sizeof(short);
break;
case TensorElementType.UInt16:
type = typeof(ushort);
width = sizeof(ushort);
break;
case TensorElementType.Int32:
type = typeof(int);
width = sizeof(int);
break;
case TensorElementType.UInt32:
type = typeof(uint);
width = sizeof(uint);
break;
case TensorElementType.Int64:
type = typeof(long);
width = sizeof(long);
break;
case TensorElementType.UInt64:
type = typeof(ulong);
width = sizeof(ulong);
break;
case TensorElementType.UInt8:
type = typeof(byte);
width = sizeof(byte);
break;
case TensorElementType.Int8:
type = typeof(sbyte);
width = sizeof(sbyte);
break;
case TensorElementType.String:
type = typeof(byte);
width = sizeof(byte);
break;
case TensorElementType.Bool:
type = typeof(bool);
width = sizeof(bool);
break;
default:
type = null;
width = 0;
break;
type = result.TensorType;
width = result.TypeSize;
}
else
{
type = null;
width = 0;
}
}
static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary<string, NodeMetadata> nodeMetaDict)
{
//Set buffer size to 4MB
@ -2243,7 +2209,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Type tensorElemType = null;
int width = 0;
GetTypeAndWidth((TensorElementType)tensor.DataType, out tensorElemType, out width);
GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width);
var intDims = new int[tensor.Dims.Count];
for (int i = 0; i < tensor.Dims.Count; i++)
{
@ -2251,7 +2217,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
NodeMetadata nodeMeta = null;
string nodeName = "";
string nodeName = string.Empty;
if (nodeMetaDict.Count == 1)
{
nodeMeta = nodeMetaDict.Values.First();
@ -2259,7 +2225,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
else if (nodeMetaDict.Count > 1)
{
if (tensor.Name != "")
if (tensor.Name.Length > 0)
{
nodeMeta = nodeMetaDict[tensor.Name];
nodeName = tensor.Name;
@ -2353,6 +2319,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
return CreateNamedOnnxValueFromRawData<bool>(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims);
}
else if (nodeMeta.ElementType == typeof(Float16))
{
return CreateNamedOnnxValueFromRawData<Float16>(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
}
else if (nodeMeta.ElementType == typeof(BFloat16))
{
return CreateNamedOnnxValueFromRawData<BFloat16>(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
}
else
{
//TODO: Add support for remaining types
@ -2363,9 +2337,24 @@ namespace Microsoft.ML.OnnxRuntime.Tests
static NamedOnnxValue CreateNamedOnnxValueFromRawData<T>(string name, byte[] rawData, int elemWidth, int[] dimensions)
{
T[] floatArr = new T[rawData.Length / elemWidth];
Buffer.BlockCopy(rawData, 0, floatArr, 0, rawData.Length);
var dt = new DenseTensor<T>(floatArr, dimensions);
T[] typedArr = new T[rawData.Length / elemWidth];
var typeOf = typeof(T);
if(typeOf == typeof(Float16) || typeOf == typeof(BFloat16))
{
using (var memSrcHandle = new Memory<byte>(rawData).Pin())
using (var memDstHandle = new Memory<T>(typedArr).Pin())
{
unsafe
{
Buffer.MemoryCopy(memSrcHandle.Pointer, memDstHandle.Pointer, typedArr.Length * elemWidth, rawData.Length);
}
}
}
else
{
Buffer.BlockCopy(rawData, 0, typedArr, 0, rawData.Length);
}
var dt = new DenseTensor<T>(typedArr, dimensions);
return NamedOnnxValue.CreateFromTensor<T>(name, dt);
}
@ -2430,7 +2419,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
public int GetHashCode(float x)
{
return 0;
return x.GetHashCode();
}
}
@ -2442,7 +2431,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
public int GetHashCode(T x)
{
return 0;
return x.GetHashCode();
}
}
/// <summary>
/// Use it to compare Float16 and BFloat16
/// </summary>
internal class Float16Comparer : IEqualityComparer<Float16>
{
public ushort tolerance;
public bool Equals(Float16 x, Float16 y)
{
return Math.Abs(x - y) <= (tolerance + y);
}
public int GetHashCode(Float16 x)
{
return x.GetHashCode();
}
}
internal class BFloat16Comparer : IEqualityComparer<BFloat16>
{
public ushort tolerance;
public bool Equals(BFloat16 x, BFloat16 y)
{
return Math.Abs(x - y) <= (tolerance + y);
}
public int GetHashCode(BFloat16 x)
{
return x.GetHashCode();
}
}

27
csharp/testdata/test_input_BFLOAT16.py vendored Normal file
View file

@ -0,0 +1,27 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import helper
from onnx.helper import make_opsetid
from onnx import AttributeProto, TensorProto, GraphProto
input_info = helper.make_tensor_value_info('input', TensorProto.BFLOAT16, [1, 5])
output_info = helper.make_tensor_value_info('output', TensorProto.BFLOAT16, [1, 5])
# Create a node (NodeProto) - This is based on Pad-11
node_def = helper.make_node(
'Identity', # node name
['input'], # inputs
['output'] # outputs
)
graph_def = helper.make_graph(nodes=[node_def], name='test_types_BLOAT16',
inputs=[input_info], outputs=[output_info])
model_def = helper.make_model(graph_def, producer_name='AIInfra',
opset_imports=[make_opsetid('', 13)])
final_model = onnx.utils.polish_model(model_def)
onnx.save(final_model, 'test_types_BFLOAT16.onnx')

30
csharp/testdata/test_input_FLOAT16.py vendored Normal file
View file

@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import helper
from onnx.helper import make_opsetid
from onnx import AttributeProto, TensorProto, GraphProto
input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT16, [1, 5])
output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 5])
# Create a node (NodeProto) - This is based on Pad-11
node_def = helper.make_node(
'Slice', # node name
['input'], # inputs
['output'], # outputs
axes=[0,1], # attributes
ends=[1,5],
starts=[0,0]
)
graph_def = helper.make_graph(nodes=[node_def], name='test_input_FLOAT16',
inputs=[input_info], outputs=[output_info])
model_def = helper.make_model(graph_def, producer_name='AIInfra',
opset_imports=[make_opsetid('', 7)])
final_model = onnx.utils.polish_model(model_def)
onnx.save(final_model, 'test_types_FLOAT16.onnx')

BIN
csharp/testdata/test_types_BFLOAT16.onnx vendored Normal file

Binary file not shown.

BIN
csharp/testdata/test_types_FLOAT16.onnx vendored Normal file

Binary file not shown.

Binary file not shown.