From 2f35e651353779dd5fd395bc472ce4a0aec6eea5 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 12 Nov 2020 17:57:08 -0800 Subject: [PATCH] Add Float16 and BFloat16 support to C# API (#5775) Add Float16 and BFloat16 support. --- .../DisposableNamedOnnxValue.cs | 24 +- .../NativeOnnxValueHelper.cs | 62 +---- .../src/Microsoft.ML.OnnxRuntime/OrtValue.cs | 10 + .../Tensors/Tensor.cs | 261 ++++++++++++++++-- .../InferenceTest.cs | 216 ++++++++------- csharp/testdata/test_input_BFLOAT16.py | 27 ++ csharp/testdata/test_input_FLOAT16.py | 30 ++ csharp/testdata/test_types_BFLOAT16.onnx | Bin 0 -> 117 bytes csharp/testdata/test_types_FLOAT16.onnx | Bin 0 -> 162 bytes csharp/testdata/test_types_FLOAT16.pb | Bin 167 -> 0 bytes 10 files changed, 452 insertions(+), 178 deletions(-) create mode 100644 csharp/testdata/test_input_BFLOAT16.py create mode 100644 csharp/testdata/test_input_FLOAT16.py create mode 100644 csharp/testdata/test_types_BFLOAT16.onnx create mode 100644 csharp/testdata/test_types_FLOAT16.onnx delete mode 100644 csharp/testdata/test_types_FLOAT16.pb diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs index a4f8832707..03bab8e920 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs @@ -5,7 +5,6 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Buffers; using System.Collections.Generic; -using System.Diagnostics; namespace Microsoft.ML.OnnxRuntime { @@ -55,8 +54,8 @@ namespace Microsoft.ML.OnnxRuntime /// tensors, sequences of tensors, sequences and maps /// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type /// The class must be disposed of. - /// It disposes of _ortValueHolder that owns the underlying Ort output value or - /// anything that the class that implements that interfaces needs to dispose. + /// It disposes of _ortValueHolder that owns the underlying Ort output value and + /// anything else that would need to be disposed by the instance of the class. /// Use factory method CreateFromOrtValue to obtain an instance of the class. /// public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable @@ -64,6 +63,19 @@ namespace Microsoft.ML.OnnxRuntime private IOrtValueOwner _ortValueHolder; private bool _disposed = false; + /// + /// Ctor + /// + /// Name of the output value + /// Managed object created to represent output value, such as DenseTensor + /// List or Dictionary + /// + /// Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary() + /// or AsEnumerable() + /// Tensor element type if value type is a Tensor + /// Object that holds native resources. + /// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be + /// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented. private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder) : base(name, value) { @@ -169,6 +181,12 @@ namespace Microsoft.ML.OnnxRuntime case TensorElementType.Bool: result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; + case TensorElementType.Float16: + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + break; + case TensorElementType.BFloat16: + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + break; default: throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported"); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs index 29d6938e4a..efb619f345 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs @@ -83,60 +83,16 @@ namespace Microsoft.ML.OnnxRuntime { public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width) { - switch (elemType) + TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType); + if(result != null) { - case TensorElementType.Float: - type = typeof(float); - width = sizeof(float); - break; - case TensorElementType.Double: - type = typeof(double); - width = sizeof(double); - break; - case TensorElementType.Int16: - type = typeof(short); - width = sizeof(short); - break; - case TensorElementType.UInt16: - type = typeof(ushort); - width = sizeof(ushort); - break; - case TensorElementType.Int32: - type = typeof(int); - width = sizeof(int); - break; - case TensorElementType.UInt32: - type = typeof(uint); - width = sizeof(uint); - break; - case TensorElementType.Int64: - type = typeof(long); - width = sizeof(long); - break; - case TensorElementType.UInt64: - type = typeof(ulong); - width = sizeof(ulong); - break; - case TensorElementType.UInt8: - type = typeof(byte); - width = sizeof(byte); - break; - case TensorElementType.Int8: - type = typeof(sbyte); - width = sizeof(sbyte); - break; - case TensorElementType.String: - type = typeof(string); - width = sizeof(byte); - break; - case TensorElementType.Bool: - type = typeof(bool); - width = sizeof(bool); - break; - default: - type = null; - width = 0; - break; + type = result.TensorType; + width = result.TypeSize; + } + else + { + type = null; + width = 0; } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs index c483daf0b8..c7ebcf2820 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs @@ -208,6 +208,16 @@ namespace Microsoft.ML.OnnxRuntime out memHandle, out dataBufferLength, out shape, out rank); break; + case TensorElementType.Float16: + PinAsTensor(value as Tensor, typeSize, + out memHandle, out dataBufferLength, + out shape, out rank); + break; + case TensorElementType.BFloat16: + PinAsTensor(value as Tensor, typeSize, + out memHandle, out dataBufferLength, + out shape, out rank); + break; default: throw new NotSupportedException("Element type: " + elType + " is not of a supported type"); } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs index 62f4be6832..eb56c10763 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs @@ -11,14 +11,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics; -using System.Text; -using System; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Reflection; +using System.Text; // Making this assembly's internals visible to the internal Test assembly [assembly: InternalsVisibleTo("Microsoft.ML.OnnxRuntime.Tests," + @@ -55,28 +54,168 @@ namespace Microsoft.ML.OnnxRuntime.Tensors DataTypeMax = 17 } - [StructLayout(LayoutKind.Sequential)] - internal struct Float16 + /// + /// This value type represents A Float16 value + /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types + /// and as such, represented the same way in managed and native memories. This means that arrays of this type + /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus, + /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library. + /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching. + /// + public struct Float16 { - public ushort Value { get; private set; } - public Float16(ushort val) + public ushort value; + /// + /// Ctor + /// + /// + public Float16(ushort v) { - Value = val; + value = v; } - } - - [StructLayout(LayoutKind.Sequential)] - internal struct BFloat16 - { - public ushort Value { get; private set; } - public BFloat16(ushort val) + /// + /// Converts to ushort + /// + /// instance of Float16 + public static implicit operator ushort (Float16 f) { return f.value; } + /// + /// Converts a 16-bit unsigned integer to a Float16. + /// + /// A 16-bit unsigned integer. + /// A Float16 that represents the converted 16-bit unsigned integer. + public static implicit operator Float16(ushort value) { return new Float16(value); } + /// + /// Compares values of two Float16 for binary equality + /// + /// + /// + /// result of value comparisons + public static bool operator ==(Float16 lhs, Float16 rhs) { return lhs.value == rhs.value; } + /// + /// Compares values of two Float16 for binary inequality + /// + /// + /// + /// result of value comparisons + public static bool operator !=(Float16 lhs, Float16 rhs) { return lhs.value != rhs.value; } + /// + /// Returns a value indicating whether this instance and other Float16 represent the same value. + /// + /// A Float16 object to compare to this instance. + /// true if other.value is equal to this instance; otherwise, false. + public bool Equals(Float16 other) { - Value = val; + return (other == this); + } + /// + /// Returns a value indicating whether this instance and a specified System.Object + /// represent the same type and value. + /// + /// An System.Object. + /// true if obj is Float16 and its value is equal to this instance; otherwise, false. + public override bool Equals(object obj) + { + bool result = false; + if (obj is Float16) + { + Float16 fl16 = (Float16)obj; + result = (fl16 == this); + } + return result; + } + /// + /// Returns the hash code for this instance. + /// + /// A 32-bit signed integer hash code. + public override int GetHashCode() + { + return value.GetHashCode(); } } /// - /// Helps typecasting. Holds primitive type information + /// This value type represents A BFloat16 value + /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types + /// and as such, represented the same way in managed and native memories. This means that arrays of this type + /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus, + /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library. + /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching. + /// + public struct BFloat16 + { + public ushort value; + /// + /// Ctor + /// + /// + public BFloat16(ushort v) + { + value = v; + } + /// + /// Converts to ushort + /// + /// instance of BFloat16 + public static implicit operator ushort(BFloat16 bf) { return bf.value; } + /// + /// Converts a 16-bit unsigned integer to a BFloat16. + /// + /// A 16-bit unsigned integer. + /// A BFloat16 that represents the converted 16-bit unsigned integer. + public static implicit operator BFloat16(ushort value) { return new BFloat16(value); } + /// + /// Compares values of two BFloat16 for binary equality + /// + /// + /// + /// result of value comparisons + public static bool operator ==(BFloat16 lhs, BFloat16 rhs) { return lhs.value == rhs.value; } + /// + /// Compares values of two BFloat16 for binary inequality + /// + /// + /// + /// result of value comparisons + public static bool operator !=(BFloat16 lhs, BFloat16 rhs) { return lhs.value != rhs.value; } + + /// + /// Returns a value indicating whether this instance and other BFloat16 represent the same value. + /// + /// A BFloat16 object to compare to this instance. + /// true if other.value is equal to this instance; otherwise, false. + public bool Equals(BFloat16 other) + { + return (other == this); + } + + /// + /// Returns a value indicating whether this instance and a specified System.Object + /// represent the same type and value. + /// + /// An System.Object. + /// true if obj is BFloat16 its value is equal to this instance; otherwise, false. + public override bool Equals(object obj) + { + bool result = false; + if (obj is BFloat16) + { + BFloat16 bfl16 = (BFloat16)obj; + result = (bfl16 == this); + } + return result; + } + /// + /// Returns the hash code for this instance. + /// + /// A 32-bit signed integer hash code. + public override int GetHashCode() + { + return value.GetHashCode(); + } + } + + /// + /// Helps typecasting. Holds Tensor element type traits. /// public class TensorTypeInfo { @@ -90,10 +229,33 @@ namespace Microsoft.ML.OnnxRuntime.Tensors } } + /// + /// Holds TensorElement traits + /// + public class TensorElementTypeInfo + { + public Type TensorType { get; private set; } + public int TypeSize { get; private set; } + public bool IsString { get; private set; } + public TensorElementTypeInfo(Type type, int typeSize) + { + TensorType = type; + TypeSize = typeSize; + IsString = type == typeof(string); + } + } + + /// + /// This class is a base for all Tensors. It hosts maps with type traits. + /// public class TensorBase { - private static readonly Dictionary typeInfoMap = - new Dictionary() + private static readonly Dictionary typeInfoMap; + + private static readonly Dictionary tensorElementTypeInfoMap; + + static TensorBase () { + typeInfoMap = new Dictionary() { { typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) }, { typeof(byte), new TensorTypeInfo( TensorElementType.UInt8, sizeof(byte)) }, @@ -111,20 +273,57 @@ namespace Microsoft.ML.OnnxRuntime.Tensors { typeof(BFloat16), new TensorTypeInfo( TensorElementType.BFloat16, sizeof(ushort)) } }; + tensorElementTypeInfoMap = new Dictionary(); + foreach(var info in typeInfoMap) + { + tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize)); + } + } + private readonly Type _primitiveType; protected TensorBase(Type primitiveType) { + // Should hold as we rely on this to pass arrays of these + // types to native code + unsafe + { + Debug.Assert(sizeof(ushort) == sizeof(Float16)); + Debug.Assert(sizeof(ushort) == sizeof(BFloat16)); + } _primitiveType = primitiveType; } + /// - /// Queries the map returns result or null + /// Query TensorTypeInfo by one of the supported types + /// + /// + /// TensorTypeInfo or null if not supported + public static TensorTypeInfo GetTypeInfo(Type type) + { + TensorTypeInfo result = null; + typeInfoMap.TryGetValue(type, out result); + return result; + } + + /// + /// Query TensorElementTypeInfo by enum + /// + /// type enum + /// instance of TensorElementTypeInfo or null if not found + public static TensorElementTypeInfo GetElementTypeInfo(TensorElementType elementType) + { + TensorElementTypeInfo result = null; + tensorElementTypeInfoMap.TryGetValue(elementType, out result); + return result; + } + + /// + /// Query TensorTypeInfo using this Tensor type /// /// public TensorTypeInfo GetTypeInfo() { - TensorTypeInfo result = null; - typeInfoMap.TryGetValue(_primitiveType, out result); - return result; + return GetTypeInfo(_primitiveType); } } @@ -312,6 +511,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors { return (T)(object)(ushort)(0); } + else if (typeof(T) == typeof(Float16)) + { + return (T)(object)(ushort)(0); + } + else if (typeof(T) == typeof(BFloat16)) + { + return (T)(object)(ushort)(0); + } throw new NotSupportedException(); } @@ -372,6 +579,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors else if (typeof(T) == typeof(ushort)) { return (T)(object)(ushort)(1); + } + else if(typeof(T) == typeof(Float16)) + { + return (T)(object)(ushort)(15360); + } + else if (typeof(T) == typeof(BFloat16)) + { + return (T)(object)(ushort)(16256); } throw new NotSupportedException(); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 006f90ce76..c17302bda1 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using Microsoft.ML.OnnxRuntime.Tensors; +using Microsoft.VisualBasic.CompilerServices; +using Onnx; using System; using System.Collections.Generic; using System.IO; @@ -676,15 +678,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests var skipModels = new Dictionary() { { "mxnet_arcface", "Model is an invalid ONNX model"}, { "tf_inception_v2", "TODO: Debug failing model, skipping for now" }, - { "fp16_inception_v1", "16-bit float not supported type in C#." }, - { "fp16_shufflenet", "16-bit float not supported type in C#." }, - { "fp16_tiny_yolov2", "16-bit float not supported type in C#." }, - { "fp16_coreml_FNS-Candy", "16-bit float not supported type in C#." }, - { "test_mnist", "16-bit float not supported type in C#." }, - { "fp16_test_shufflenet", "16-bit float not supported type in C#." }, - { "fp16_coreml_LinearRegression_NYCTaxi", "16-bit float not supported type in C#." }, - { "test_bidaf", "16-bit float not supported type in C#." }, - { "fp16_test_tiny_yolov2", "16-bit float not supported type in C#." }, + { "fp16_tiny_yolov2", "Tolerance level for float16 is not known. We now support fp16." }, + { "test_bidaf", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile." }, + { "test_mnist", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile" }, { "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" }, { "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" }, { "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" }, @@ -826,7 +822,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string testDataDirNamePattern = "test_data*"; if (opset == "opset9" && modelName == "LSTM_Seq_lens_unpacked") { - testDataDirNamePattern = "seq_lens*"; // discrepency in data directory + testDataDirNamePattern = "seq_lens*"; // discrepancy in data directory } foreach (var testDataDir in modelDir.EnumerateDirectories(testDataDirNamePattern)) { @@ -898,6 +894,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests { Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); } + else if (outputMeta.ElementType == typeof(Float16)) + { + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new Float16Comparer { tolerance = 2 }); + } + else if (outputMeta.ElementType == typeof(BFloat16)) + { + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new BFloat16Comparer { tolerance = 2 }); + } else { Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType)); @@ -1520,20 +1524,47 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } - [Fact(Skip = "FLOAT16 not available in C#")] + [Fact] private void TestModelInputFLOAT16() { // model takes 1x5 input of fixed type, echoes back - string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.pb"); + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.onnx"); using (var session = new InferenceSession(modelPath)) { var container = new List(); - var tensorIn = new DenseTensor(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 }); + var tensorIn = new DenseTensor( + new Float16[] { 15360, 16384, 16896, 17408, 17664 }, new int[] { 1, 5 }); var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn); container.Add(nov); using (var res = session.Run(container)) { - var tensorOut = res.First().AsTensor(); + var valueOut = res.First(); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType); + Assert.Equal(Tensors.TensorElementType.Float16, valueOut.ElementType); + var tensorOut = res.First().AsTensor(); + Assert.True(tensorOut.SequenceEqual(tensorIn)); + } + } + } + + [Fact] + private void TestModelInputBFLOAT16() + { + // model takes 1x5 input of fixed type, echoes back + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_BFLOAT16.onnx"); + using (var session = new InferenceSession(modelPath)) + { + var container = new List(); + var tensorIn = new DenseTensor( + new BFloat16[] { 16256, 16384, 16448, 16512, 16544 }, new int[] { 1, 5 }); + var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn); + container.Add(nov); + using (var res = session.Run(container)) + { + var valueOut = res.First(); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType); + Assert.Equal(Tensors.TensorElementType.BFloat16, valueOut.ElementType); + var tensorOut = res.First().AsTensor(); Assert.True(tensorOut.SequenceEqual(tensorIn)); } } @@ -1580,7 +1611,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests var outNode0 = outputs.ElementAtOrDefault(0); Assert.Equal("label", outNode0.Name); Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outNode0.ValueType); - Assert.Equal(TensorElementType.Int64, (TensorElementType)outNode0.ElementType); + Assert.Equal(Tensors.TensorElementType.Int64, outNode0.ElementType); // try-cast as a tensor var outLabelTensor = outNode0.AsTensor(); @@ -2151,86 +2182,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests return tensorData.ToArray(); } - - private enum TensorElementType + private static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width) { - Float = 1, - UInt8 = 2, - Int8 = 3, - UInt16 = 4, - Int16 = 5, - Int32 = 6, - Int64 = 7, - String = 8, - Bool = 9, - Float16 = 10, - Double = 11, - UInt32 = 12, - UInt64 = 13, - Complex64 = 14, - Complex128 = 15, - BFloat16 = 16, - DataTypeMax = 17 - } - - private static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width) - { - switch (elemType) + TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType); + if (result != null) { - case TensorElementType.Float: - type = typeof(float); - width = sizeof(float); - break; - case TensorElementType.Double: - type = typeof(double); - width = sizeof(double); - break; - case TensorElementType.Int16: - type = typeof(short); - width = sizeof(short); - break; - case TensorElementType.UInt16: - type = typeof(ushort); - width = sizeof(ushort); - break; - case TensorElementType.Int32: - type = typeof(int); - width = sizeof(int); - break; - case TensorElementType.UInt32: - type = typeof(uint); - width = sizeof(uint); - break; - case TensorElementType.Int64: - type = typeof(long); - width = sizeof(long); - break; - case TensorElementType.UInt64: - type = typeof(ulong); - width = sizeof(ulong); - break; - case TensorElementType.UInt8: - type = typeof(byte); - width = sizeof(byte); - break; - case TensorElementType.Int8: - type = typeof(sbyte); - width = sizeof(sbyte); - break; - case TensorElementType.String: - type = typeof(byte); - width = sizeof(byte); - break; - case TensorElementType.Bool: - type = typeof(bool); - width = sizeof(bool); - break; - default: - type = null; - width = 0; - break; + type = result.TensorType; + width = result.TypeSize; + } + else + { + type = null; + width = 0; } } + static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary nodeMetaDict) { //Set buffer size to 4MB @@ -2243,7 +2209,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests Type tensorElemType = null; int width = 0; - GetTypeAndWidth((TensorElementType)tensor.DataType, out tensorElemType, out width); + GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width); var intDims = new int[tensor.Dims.Count]; for (int i = 0; i < tensor.Dims.Count; i++) { @@ -2251,7 +2217,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests } NodeMetadata nodeMeta = null; - string nodeName = ""; + string nodeName = string.Empty; if (nodeMetaDict.Count == 1) { nodeMeta = nodeMetaDict.Values.First(); @@ -2259,7 +2225,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests } else if (nodeMetaDict.Count > 1) { - if (tensor.Name != "") + if (tensor.Name.Length > 0) { nodeMeta = nodeMetaDict[tensor.Name]; nodeName = tensor.Name; @@ -2353,6 +2319,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests { return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims); } + else if (nodeMeta.ElementType == typeof(Float16)) + { + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + } + else if (nodeMeta.ElementType == typeof(BFloat16)) + { + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + } else { //TODO: Add support for remaining types @@ -2363,9 +2337,24 @@ namespace Microsoft.ML.OnnxRuntime.Tests static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, byte[] rawData, int elemWidth, int[] dimensions) { - T[] floatArr = new T[rawData.Length / elemWidth]; - Buffer.BlockCopy(rawData, 0, floatArr, 0, rawData.Length); - var dt = new DenseTensor(floatArr, dimensions); + T[] typedArr = new T[rawData.Length / elemWidth]; + var typeOf = typeof(T); + if(typeOf == typeof(Float16) || typeOf == typeof(BFloat16)) + { + using (var memSrcHandle = new Memory(rawData).Pin()) + using (var memDstHandle = new Memory(typedArr).Pin()) + { + unsafe + { + Buffer.MemoryCopy(memSrcHandle.Pointer, memDstHandle.Pointer, typedArr.Length * elemWidth, rawData.Length); + } + } + } + else + { + Buffer.BlockCopy(rawData, 0, typedArr, 0, rawData.Length); + } + var dt = new DenseTensor(typedArr, dimensions); return NamedOnnxValue.CreateFromTensor(name, dt); } @@ -2430,7 +2419,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests } public int GetHashCode(float x) { - return 0; + return x.GetHashCode(); } } @@ -2442,7 +2431,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests } public int GetHashCode(T x) { - return 0; + return x.GetHashCode(); + } + } + + /// + /// Use it to compare Float16 and BFloat16 + /// + internal class Float16Comparer : IEqualityComparer + { + public ushort tolerance; + public bool Equals(Float16 x, Float16 y) + { + return Math.Abs(x - y) <= (tolerance + y); + } + public int GetHashCode(Float16 x) + { + return x.GetHashCode(); + } + } + + internal class BFloat16Comparer : IEqualityComparer + { + public ushort tolerance; + public bool Equals(BFloat16 x, BFloat16 y) + { + return Math.Abs(x - y) <= (tolerance + y); + } + public int GetHashCode(BFloat16 x) + { + return x.GetHashCode(); } } diff --git a/csharp/testdata/test_input_BFLOAT16.py b/csharp/testdata/test_input_BFLOAT16.py new file mode 100644 index 0000000000..c9929898cb --- /dev/null +++ b/csharp/testdata/test_input_BFLOAT16.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import onnx +from onnx import helper +from onnx.helper import make_opsetid +from onnx import AttributeProto, TensorProto, GraphProto + +input_info = helper.make_tensor_value_info('input', TensorProto.BFLOAT16, [1, 5]) +output_info = helper.make_tensor_value_info('output', TensorProto.BFLOAT16, [1, 5]) + +# Create a node (NodeProto) - This is based on Pad-11 +node_def = helper.make_node( + 'Identity', # node name + ['input'], # inputs + ['output'] # outputs +) + +graph_def = helper.make_graph(nodes=[node_def], name='test_types_BLOAT16', + inputs=[input_info], outputs=[output_info]) + +model_def = helper.make_model(graph_def, producer_name='AIInfra', + opset_imports=[make_opsetid('', 13)]) + +final_model = onnx.utils.polish_model(model_def) +onnx.save(final_model, 'test_types_BFLOAT16.onnx') + diff --git a/csharp/testdata/test_input_FLOAT16.py b/csharp/testdata/test_input_FLOAT16.py new file mode 100644 index 0000000000..1de04a1a8a --- /dev/null +++ b/csharp/testdata/test_input_FLOAT16.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import onnx +from onnx import helper +from onnx.helper import make_opsetid +from onnx import AttributeProto, TensorProto, GraphProto + +input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT16, [1, 5]) +output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 5]) + +# Create a node (NodeProto) - This is based on Pad-11 +node_def = helper.make_node( + 'Slice', # node name + ['input'], # inputs + ['output'], # outputs + axes=[0,1], # attributes + ends=[1,5], + starts=[0,0] +) + +graph_def = helper.make_graph(nodes=[node_def], name='test_input_FLOAT16', + inputs=[input_info], outputs=[output_info]) + +model_def = helper.make_model(graph_def, producer_name='AIInfra', + opset_imports=[make_opsetid('', 7)]) + +final_model = onnx.utils.polish_model(model_def) +onnx.save(final_model, 'test_types_FLOAT16.onnx') + diff --git a/csharp/testdata/test_types_BFLOAT16.onnx b/csharp/testdata/test_types_BFLOAT16.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7875cfb503091b5e2238ca25d13d750b58cd95ae GIT binary patch literal 117 zcmd;J7h-qx^vp{uO0-JilH_8|%qu7@5n{_PEdkO>9G)quc_o=8l|n)#sl_GnC6xuK t#qmx){*EDrW>MlW<$PQ`90EcdTudB{K+KvX!38xy2u;R`g^NLe7XU4I8WI2i literal 0 HcmV?d00001 diff --git a/csharp/testdata/test_types_FLOAT16.onnx b/csharp/testdata/test_types_FLOAT16.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c8e73b3fd358509d5abe1aa0691086d435752569 GIT binary patch literal 162 zcmd;J7h-qx^vp{uO0?=@02B}sDoHIai3gh$@8;w07-DD^B@Q#1kBf(c cONfJuiGvY{S(7BVpq2@t$vClaF$k~&0L~;MNdN!< literal 0 HcmV?d00001 diff --git a/csharp/testdata/test_types_FLOAT16.pb b/csharp/testdata/test_types_FLOAT16.pb deleted file mode 100644 index f671bb7ce3253f1caf071fa168560869aaaf32eb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 167 zcmd;J7Gihw^vp{uO0=5F$mPn#nweKnS|Y@jUs?jBl~{vwGLuuac)3^-D^iOc7#tWE zFtUTVsd*{I4vY@0ATB=_TX9KZQ3*(%0Vp8EnUa~7R#I7zS{x-0Gf|97go8^+fQyNP T5g8{ba6#=7!lA;6Nq`*y^tT}f