diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs
index a4f8832707..03bab8e920 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs
@@ -5,7 +5,6 @@ using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;
using System.Collections.Generic;
-using System.Diagnostics;
namespace Microsoft.ML.OnnxRuntime
{
@@ -55,8 +54,8 @@ namespace Microsoft.ML.OnnxRuntime
/// tensors, sequences of tensors, sequences and maps
/// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type
/// The class must be disposed of.
- /// It disposes of _ortValueHolder that owns the underlying Ort output value or
- /// anything that the class that implements that interfaces needs to dispose.
+ /// It disposes of _ortValueHolder that owns the underlying Ort output value and
+ /// anything else that would need to be disposed by the instance of the class.
/// Use factory method CreateFromOrtValue to obtain an instance of the class.
///
public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable
@@ -64,6 +63,19 @@ namespace Microsoft.ML.OnnxRuntime
private IOrtValueOwner _ortValueHolder;
private bool _disposed = false;
+ ///
+ /// Ctor
+ ///
+ /// Name of the output value
+ /// Managed object created to represent output value, such as DenseTensor
+ /// List or Dictionary
+ ///
+ /// Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary()
+ /// or AsEnumerable()
+ /// Tensor element type if value type is a Tensor
+ /// Object that holds native resources.
+ /// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be
+ /// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented.
private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder)
: base(name, value)
{
@@ -169,6 +181,12 @@ namespace Microsoft.ML.OnnxRuntime
case TensorElementType.Bool:
result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
break;
+ case TensorElementType.Float16:
+ result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ break;
+ case TensorElementType.BFloat16:
+ result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ break;
default:
throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported");
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs
index 29d6938e4a..efb619f345 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs
@@ -83,60 +83,16 @@ namespace Microsoft.ML.OnnxRuntime
{
public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
{
- switch (elemType)
+ TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
+ if(result != null)
{
- case TensorElementType.Float:
- type = typeof(float);
- width = sizeof(float);
- break;
- case TensorElementType.Double:
- type = typeof(double);
- width = sizeof(double);
- break;
- case TensorElementType.Int16:
- type = typeof(short);
- width = sizeof(short);
- break;
- case TensorElementType.UInt16:
- type = typeof(ushort);
- width = sizeof(ushort);
- break;
- case TensorElementType.Int32:
- type = typeof(int);
- width = sizeof(int);
- break;
- case TensorElementType.UInt32:
- type = typeof(uint);
- width = sizeof(uint);
- break;
- case TensorElementType.Int64:
- type = typeof(long);
- width = sizeof(long);
- break;
- case TensorElementType.UInt64:
- type = typeof(ulong);
- width = sizeof(ulong);
- break;
- case TensorElementType.UInt8:
- type = typeof(byte);
- width = sizeof(byte);
- break;
- case TensorElementType.Int8:
- type = typeof(sbyte);
- width = sizeof(sbyte);
- break;
- case TensorElementType.String:
- type = typeof(string);
- width = sizeof(byte);
- break;
- case TensorElementType.Bool:
- type = typeof(bool);
- width = sizeof(bool);
- break;
- default:
- type = null;
- width = 0;
- break;
+ type = result.TensorType;
+ width = result.TypeSize;
+ }
+ else
+ {
+ type = null;
+ width = 0;
}
}
}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs
index c483daf0b8..c7ebcf2820 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.cs
@@ -208,6 +208,16 @@ namespace Microsoft.ML.OnnxRuntime
out memHandle, out dataBufferLength,
out shape, out rank);
break;
+ case TensorElementType.Float16:
+ PinAsTensor(value as Tensor, typeSize,
+ out memHandle, out dataBufferLength,
+ out shape, out rank);
+ break;
+ case TensorElementType.BFloat16:
+ PinAsTensor(value as Tensor, typeSize,
+ out memHandle, out dataBufferLength,
+ out shape, out rank);
+ break;
default:
throw new NotSupportedException("Element type: " + elType + " is not of a supported type");
}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs
index 62f4be6832..eb56c10763 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.cs
@@ -11,14 +11,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
-using System.Text;
-using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
-using System.Reflection;
+using System.Text;
// Making this assembly's internals visible to the internal Test assembly
[assembly: InternalsVisibleTo("Microsoft.ML.OnnxRuntime.Tests," +
@@ -55,28 +54,168 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
DataTypeMax = 17
}
- [StructLayout(LayoutKind.Sequential)]
- internal struct Float16
+ ///
+ /// This value type represents A Float16 value
+ /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
+ /// and as such, represented the same way in managed and native memories. This means that arrays of this type
+ /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
+ /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
+ /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
+ ///
+ public struct Float16
{
- public ushort Value { get; private set; }
- public Float16(ushort val)
+ public ushort value;
+ ///
+ /// Ctor
+ ///
+ ///
+ public Float16(ushort v)
{
- Value = val;
+ value = v;
}
- }
-
- [StructLayout(LayoutKind.Sequential)]
- internal struct BFloat16
- {
- public ushort Value { get; private set; }
- public BFloat16(ushort val)
+ ///
+ /// Converts to ushort
+ ///
+ /// instance of Float16
+ public static implicit operator ushort (Float16 f) { return f.value; }
+ ///
+ /// Converts a 16-bit unsigned integer to a Float16.
+ ///
+ /// A 16-bit unsigned integer.
+ /// A Float16 that represents the converted 16-bit unsigned integer.
+ public static implicit operator Float16(ushort value) { return new Float16(value); }
+ ///
+ /// Compares values of two Float16 for binary equality
+ ///
+ ///
+ ///
+ /// result of value comparisons
+ public static bool operator ==(Float16 lhs, Float16 rhs) { return lhs.value == rhs.value; }
+ ///
+ /// Compares values of two Float16 for binary inequality
+ ///
+ ///
+ ///
+ /// result of value comparisons
+ public static bool operator !=(Float16 lhs, Float16 rhs) { return lhs.value != rhs.value; }
+ ///
+ /// Returns a value indicating whether this instance and other Float16 represent the same value.
+ ///
+ /// A Float16 object to compare to this instance.
+ /// true if other.value is equal to this instance; otherwise, false.
+ public bool Equals(Float16 other)
{
- Value = val;
+ return (other == this);
+ }
+ ///
+ /// Returns a value indicating whether this instance and a specified System.Object
+ /// represent the same type and value.
+ ///
+ /// An System.Object.
+ /// true if obj is Float16 and its value is equal to this instance; otherwise, false.
+ public override bool Equals(object obj)
+ {
+ bool result = false;
+ if (obj is Float16)
+ {
+ Float16 fl16 = (Float16)obj;
+ result = (fl16 == this);
+ }
+ return result;
+ }
+ ///
+ /// Returns the hash code for this instance.
+ ///
+ /// A 32-bit signed integer hash code.
+ public override int GetHashCode()
+ {
+ return value.GetHashCode();
}
}
///
- /// Helps typecasting. Holds primitive type information
+ /// This value type represents A BFloat16 value
+ /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
+ /// and as such, represented the same way in managed and native memories. This means that arrays of this type
+ /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
+ /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
+ /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
+ ///
+ public struct BFloat16
+ {
+ public ushort value;
+ ///
+ /// Ctor
+ ///
+ ///
+ public BFloat16(ushort v)
+ {
+ value = v;
+ }
+ ///
+ /// Converts to ushort
+ ///
+ /// instance of BFloat16
+ public static implicit operator ushort(BFloat16 bf) { return bf.value; }
+ ///
+ /// Converts a 16-bit unsigned integer to a BFloat16.
+ ///
+ /// A 16-bit unsigned integer.
+ /// A BFloat16 that represents the converted 16-bit unsigned integer.
+ public static implicit operator BFloat16(ushort value) { return new BFloat16(value); }
+ ///
+ /// Compares values of two BFloat16 for binary equality
+ ///
+ ///
+ ///
+ /// result of value comparisons
+ public static bool operator ==(BFloat16 lhs, BFloat16 rhs) { return lhs.value == rhs.value; }
+ ///
+ /// Compares values of two BFloat16 for binary inequality
+ ///
+ ///
+ ///
+ /// result of value comparisons
+ public static bool operator !=(BFloat16 lhs, BFloat16 rhs) { return lhs.value != rhs.value; }
+
+ ///
+ /// Returns a value indicating whether this instance and other BFloat16 represent the same value.
+ ///
+ /// A BFloat16 object to compare to this instance.
+ /// true if other.value is equal to this instance; otherwise, false.
+ public bool Equals(BFloat16 other)
+ {
+ return (other == this);
+ }
+
+ ///
+ /// Returns a value indicating whether this instance and a specified System.Object
+ /// represent the same type and value.
+ ///
+ /// An System.Object.
+ /// true if obj is BFloat16 its value is equal to this instance; otherwise, false.
+ public override bool Equals(object obj)
+ {
+ bool result = false;
+ if (obj is BFloat16)
+ {
+ BFloat16 bfl16 = (BFloat16)obj;
+ result = (bfl16 == this);
+ }
+ return result;
+ }
+ ///
+ /// Returns the hash code for this instance.
+ ///
+ /// A 32-bit signed integer hash code.
+ public override int GetHashCode()
+ {
+ return value.GetHashCode();
+ }
+ }
+
+ ///
+ /// Helps typecasting. Holds Tensor element type traits.
///
public class TensorTypeInfo
{
@@ -90,10 +229,33 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
}
}
+ ///
+ /// Holds TensorElement traits
+ ///
+ public class TensorElementTypeInfo
+ {
+ public Type TensorType { get; private set; }
+ public int TypeSize { get; private set; }
+ public bool IsString { get; private set; }
+ public TensorElementTypeInfo(Type type, int typeSize)
+ {
+ TensorType = type;
+ TypeSize = typeSize;
+ IsString = type == typeof(string);
+ }
+ }
+
+ ///
+ /// This class is a base for all Tensors. It hosts maps with type traits.
+ ///
public class TensorBase
{
- private static readonly Dictionary typeInfoMap =
- new Dictionary()
+ private static readonly Dictionary typeInfoMap;
+
+ private static readonly Dictionary tensorElementTypeInfoMap;
+
+ static TensorBase () {
+ typeInfoMap = new Dictionary()
{
{ typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) },
{ typeof(byte), new TensorTypeInfo( TensorElementType.UInt8, sizeof(byte)) },
@@ -111,20 +273,57 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
{ typeof(BFloat16), new TensorTypeInfo( TensorElementType.BFloat16, sizeof(ushort)) }
};
+ tensorElementTypeInfoMap = new Dictionary();
+ foreach(var info in typeInfoMap)
+ {
+ tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize));
+ }
+ }
+
private readonly Type _primitiveType;
protected TensorBase(Type primitiveType)
{
+ // Should hold as we rely on this to pass arrays of these
+ // types to native code
+ unsafe
+ {
+ Debug.Assert(sizeof(ushort) == sizeof(Float16));
+ Debug.Assert(sizeof(ushort) == sizeof(BFloat16));
+ }
_primitiveType = primitiveType;
}
+
///
- /// Queries the map returns result or null
+ /// Query TensorTypeInfo by one of the supported types
+ ///
+ ///
+ /// TensorTypeInfo or null if not supported
+ public static TensorTypeInfo GetTypeInfo(Type type)
+ {
+ TensorTypeInfo result = null;
+ typeInfoMap.TryGetValue(type, out result);
+ return result;
+ }
+
+ ///
+ /// Query TensorElementTypeInfo by enum
+ ///
+ /// type enum
+ /// instance of TensorElementTypeInfo or null if not found
+ public static TensorElementTypeInfo GetElementTypeInfo(TensorElementType elementType)
+ {
+ TensorElementTypeInfo result = null;
+ tensorElementTypeInfoMap.TryGetValue(elementType, out result);
+ return result;
+ }
+
+ ///
+ /// Query TensorTypeInfo using this Tensor type
///
///
public TensorTypeInfo GetTypeInfo()
{
- TensorTypeInfo result = null;
- typeInfoMap.TryGetValue(_primitiveType, out result);
- return result;
+ return GetTypeInfo(_primitiveType);
}
}
@@ -312,6 +511,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
{
return (T)(object)(ushort)(0);
}
+ else if (typeof(T) == typeof(Float16))
+ {
+ return (T)(object)(ushort)(0);
+ }
+ else if (typeof(T) == typeof(BFloat16))
+ {
+ return (T)(object)(ushort)(0);
+ }
throw new NotSupportedException();
}
@@ -372,6 +579,14 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
else if (typeof(T) == typeof(ushort))
{
return (T)(object)(ushort)(1);
+ }
+ else if(typeof(T) == typeof(Float16))
+ {
+ return (T)(object)(ushort)(15360);
+ }
+ else if (typeof(T) == typeof(BFloat16))
+ {
+ return (T)(object)(ushort)(16256);
}
throw new NotSupportedException();
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index 006f90ce76..c17302bda1 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -2,6 +2,8 @@
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
+using Microsoft.VisualBasic.CompilerServices;
+using Onnx;
using System;
using System.Collections.Generic;
using System.IO;
@@ -676,15 +678,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var skipModels = new Dictionary() {
{ "mxnet_arcface", "Model is an invalid ONNX model"},
{ "tf_inception_v2", "TODO: Debug failing model, skipping for now" },
- { "fp16_inception_v1", "16-bit float not supported type in C#." },
- { "fp16_shufflenet", "16-bit float not supported type in C#." },
- { "fp16_tiny_yolov2", "16-bit float not supported type in C#." },
- { "fp16_coreml_FNS-Candy", "16-bit float not supported type in C#." },
- { "test_mnist", "16-bit float not supported type in C#." },
- { "fp16_test_shufflenet", "16-bit float not supported type in C#." },
- { "fp16_coreml_LinearRegression_NYCTaxi", "16-bit float not supported type in C#." },
- { "test_bidaf", "16-bit float not supported type in C#." },
- { "fp16_test_tiny_yolov2", "16-bit float not supported type in C#." },
+ { "fp16_tiny_yolov2", "Tolerance level for float16 is not known. We now support fp16." },
+ { "test_bidaf", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile." },
+ { "test_mnist", "Does not run in opset9, runs in other opsets. Tensors of type ElementType not currently supported in the LoadTensorFromFile" },
{ "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" },
{ "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" },
{ "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" },
@@ -826,7 +822,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string testDataDirNamePattern = "test_data*";
if (opset == "opset9" && modelName == "LSTM_Seq_lens_unpacked")
{
- testDataDirNamePattern = "seq_lens*"; // discrepency in data directory
+ testDataDirNamePattern = "seq_lens*"; // discrepancy in data directory
}
foreach (var testDataDir in modelDir.EnumerateDirectories(testDataDirNamePattern))
{
@@ -898,6 +894,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer());
}
+ else if (outputMeta.ElementType == typeof(Float16))
+ {
+ Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new Float16Comparer { tolerance = 2 });
+ }
+ else if (outputMeta.ElementType == typeof(BFloat16))
+ {
+ Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new BFloat16Comparer { tolerance = 2 });
+ }
else
{
Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType));
@@ -1520,20 +1524,47 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
- [Fact(Skip = "FLOAT16 not available in C#")]
+ [Fact]
private void TestModelInputFLOAT16()
{
// model takes 1x5 input of fixed type, echoes back
- string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.pb");
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_FLOAT16.onnx");
using (var session = new InferenceSession(modelPath))
{
var container = new List();
- var tensorIn = new DenseTensor(new float[] { 1.0f, 2.0f, -3.0f, float.MinValue, float.MaxValue }, new int[] { 1, 5 });
+ var tensorIn = new DenseTensor(
+ new Float16[] { 15360, 16384, 16896, 17408, 17664 }, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
using (var res = session.Run(container))
{
- var tensorOut = res.First().AsTensor();
+ var valueOut = res.First();
+ Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType);
+ Assert.Equal(Tensors.TensorElementType.Float16, valueOut.ElementType);
+ var tensorOut = res.First().AsTensor();
+ Assert.True(tensorOut.SequenceEqual(tensorIn));
+ }
+ }
+ }
+
+ [Fact]
+ private void TestModelInputBFLOAT16()
+ {
+ // model takes 1x5 input of fixed type, echoes back
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_BFLOAT16.onnx");
+ using (var session = new InferenceSession(modelPath))
+ {
+ var container = new List();
+ var tensorIn = new DenseTensor(
+ new BFloat16[] { 16256, 16384, 16448, 16512, 16544 }, new int[] { 1, 5 });
+ var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
+ container.Add(nov);
+ using (var res = session.Run(container))
+ {
+ var valueOut = res.First();
+ Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, valueOut.ValueType);
+ Assert.Equal(Tensors.TensorElementType.BFloat16, valueOut.ElementType);
+ var tensorOut = res.First().AsTensor();
Assert.True(tensorOut.SequenceEqual(tensorIn));
}
}
@@ -1580,7 +1611,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var outNode0 = outputs.ElementAtOrDefault(0);
Assert.Equal("label", outNode0.Name);
Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outNode0.ValueType);
- Assert.Equal(TensorElementType.Int64, (TensorElementType)outNode0.ElementType);
+ Assert.Equal(Tensors.TensorElementType.Int64, outNode0.ElementType);
// try-cast as a tensor
var outLabelTensor = outNode0.AsTensor();
@@ -2151,86 +2182,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
return tensorData.ToArray();
}
-
- private enum TensorElementType
+ private static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width)
{
- Float = 1,
- UInt8 = 2,
- Int8 = 3,
- UInt16 = 4,
- Int16 = 5,
- Int32 = 6,
- Int64 = 7,
- String = 8,
- Bool = 9,
- Float16 = 10,
- Double = 11,
- UInt32 = 12,
- UInt64 = 13,
- Complex64 = 14,
- Complex128 = 15,
- BFloat16 = 16,
- DataTypeMax = 17
- }
-
- private static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
- {
- switch (elemType)
+ TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
+ if (result != null)
{
- case TensorElementType.Float:
- type = typeof(float);
- width = sizeof(float);
- break;
- case TensorElementType.Double:
- type = typeof(double);
- width = sizeof(double);
- break;
- case TensorElementType.Int16:
- type = typeof(short);
- width = sizeof(short);
- break;
- case TensorElementType.UInt16:
- type = typeof(ushort);
- width = sizeof(ushort);
- break;
- case TensorElementType.Int32:
- type = typeof(int);
- width = sizeof(int);
- break;
- case TensorElementType.UInt32:
- type = typeof(uint);
- width = sizeof(uint);
- break;
- case TensorElementType.Int64:
- type = typeof(long);
- width = sizeof(long);
- break;
- case TensorElementType.UInt64:
- type = typeof(ulong);
- width = sizeof(ulong);
- break;
- case TensorElementType.UInt8:
- type = typeof(byte);
- width = sizeof(byte);
- break;
- case TensorElementType.Int8:
- type = typeof(sbyte);
- width = sizeof(sbyte);
- break;
- case TensorElementType.String:
- type = typeof(byte);
- width = sizeof(byte);
- break;
- case TensorElementType.Bool:
- type = typeof(bool);
- width = sizeof(bool);
- break;
- default:
- type = null;
- width = 0;
- break;
+ type = result.TensorType;
+ width = result.TypeSize;
+ }
+ else
+ {
+ type = null;
+ width = 0;
}
}
+
static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary nodeMetaDict)
{
//Set buffer size to 4MB
@@ -2243,7 +2209,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
Type tensorElemType = null;
int width = 0;
- GetTypeAndWidth((TensorElementType)tensor.DataType, out tensorElemType, out width);
+ GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width);
var intDims = new int[tensor.Dims.Count];
for (int i = 0; i < tensor.Dims.Count; i++)
{
@@ -2251,7 +2217,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
NodeMetadata nodeMeta = null;
- string nodeName = "";
+ string nodeName = string.Empty;
if (nodeMetaDict.Count == 1)
{
nodeMeta = nodeMetaDict.Values.First();
@@ -2259,7 +2225,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
else if (nodeMetaDict.Count > 1)
{
- if (tensor.Name != "")
+ if (tensor.Name.Length > 0)
{
nodeMeta = nodeMetaDict[tensor.Name];
nodeName = tensor.Name;
@@ -2353,6 +2319,14 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims);
}
+ else if (nodeMeta.ElementType == typeof(Float16))
+ {
+ return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
+ }
+ else if (nodeMeta.ElementType == typeof(BFloat16))
+ {
+ return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
+ }
else
{
//TODO: Add support for remaining types
@@ -2363,9 +2337,24 @@ namespace Microsoft.ML.OnnxRuntime.Tests
static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, byte[] rawData, int elemWidth, int[] dimensions)
{
- T[] floatArr = new T[rawData.Length / elemWidth];
- Buffer.BlockCopy(rawData, 0, floatArr, 0, rawData.Length);
- var dt = new DenseTensor(floatArr, dimensions);
+ T[] typedArr = new T[rawData.Length / elemWidth];
+ var typeOf = typeof(T);
+ if(typeOf == typeof(Float16) || typeOf == typeof(BFloat16))
+ {
+ using (var memSrcHandle = new Memory(rawData).Pin())
+ using (var memDstHandle = new Memory(typedArr).Pin())
+ {
+ unsafe
+ {
+ Buffer.MemoryCopy(memSrcHandle.Pointer, memDstHandle.Pointer, typedArr.Length * elemWidth, rawData.Length);
+ }
+ }
+ }
+ else
+ {
+ Buffer.BlockCopy(rawData, 0, typedArr, 0, rawData.Length);
+ }
+ var dt = new DenseTensor(typedArr, dimensions);
return NamedOnnxValue.CreateFromTensor(name, dt);
}
@@ -2430,7 +2419,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
public int GetHashCode(float x)
{
- return 0;
+ return x.GetHashCode();
}
}
@@ -2442,7 +2431,36 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
public int GetHashCode(T x)
{
- return 0;
+ return x.GetHashCode();
+ }
+ }
+
+ ///
+ /// Use it to compare Float16 and BFloat16
+ ///
+ internal class Float16Comparer : IEqualityComparer
+ {
+ public ushort tolerance;
+ public bool Equals(Float16 x, Float16 y)
+ {
+ return Math.Abs(x - y) <= (tolerance + y);
+ }
+ public int GetHashCode(Float16 x)
+ {
+ return x.GetHashCode();
+ }
+ }
+
+ internal class BFloat16Comparer : IEqualityComparer
+ {
+ public ushort tolerance;
+ public bool Equals(BFloat16 x, BFloat16 y)
+ {
+ return Math.Abs(x - y) <= (tolerance + y);
+ }
+ public int GetHashCode(BFloat16 x)
+ {
+ return x.GetHashCode();
}
}
diff --git a/csharp/testdata/test_input_BFLOAT16.py b/csharp/testdata/test_input_BFLOAT16.py
new file mode 100644
index 0000000000..c9929898cb
--- /dev/null
+++ b/csharp/testdata/test_input_BFLOAT16.py
@@ -0,0 +1,27 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import helper
+from onnx.helper import make_opsetid
+from onnx import AttributeProto, TensorProto, GraphProto
+
+input_info = helper.make_tensor_value_info('input', TensorProto.BFLOAT16, [1, 5])
+output_info = helper.make_tensor_value_info('output', TensorProto.BFLOAT16, [1, 5])
+
+# Create a node (NodeProto) - This is based on Pad-11
+node_def = helper.make_node(
+ 'Identity', # node name
+ ['input'], # inputs
+ ['output'] # outputs
+)
+
+graph_def = helper.make_graph(nodes=[node_def], name='test_types_BLOAT16',
+ inputs=[input_info], outputs=[output_info])
+
+model_def = helper.make_model(graph_def, producer_name='AIInfra',
+ opset_imports=[make_opsetid('', 13)])
+
+final_model = onnx.utils.polish_model(model_def)
+onnx.save(final_model, 'test_types_BFLOAT16.onnx')
+
diff --git a/csharp/testdata/test_input_FLOAT16.py b/csharp/testdata/test_input_FLOAT16.py
new file mode 100644
index 0000000000..1de04a1a8a
--- /dev/null
+++ b/csharp/testdata/test_input_FLOAT16.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import helper
+from onnx.helper import make_opsetid
+from onnx import AttributeProto, TensorProto, GraphProto
+
+input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT16, [1, 5])
+output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 5])
+
+# Create a node (NodeProto) - This is based on Pad-11
+node_def = helper.make_node(
+ 'Slice', # node name
+ ['input'], # inputs
+ ['output'], # outputs
+ axes=[0,1], # attributes
+ ends=[1,5],
+ starts=[0,0]
+)
+
+graph_def = helper.make_graph(nodes=[node_def], name='test_input_FLOAT16',
+ inputs=[input_info], outputs=[output_info])
+
+model_def = helper.make_model(graph_def, producer_name='AIInfra',
+ opset_imports=[make_opsetid('', 7)])
+
+final_model = onnx.utils.polish_model(model_def)
+onnx.save(final_model, 'test_types_FLOAT16.onnx')
+
diff --git a/csharp/testdata/test_types_BFLOAT16.onnx b/csharp/testdata/test_types_BFLOAT16.onnx
new file mode 100644
index 0000000000..7875cfb503
Binary files /dev/null and b/csharp/testdata/test_types_BFLOAT16.onnx differ
diff --git a/csharp/testdata/test_types_FLOAT16.onnx b/csharp/testdata/test_types_FLOAT16.onnx
new file mode 100644
index 0000000000..c8e73b3fd3
Binary files /dev/null and b/csharp/testdata/test_types_FLOAT16.onnx differ
diff --git a/csharp/testdata/test_types_FLOAT16.pb b/csharp/testdata/test_types_FLOAT16.pb
deleted file mode 100644
index f671bb7ce3..0000000000
Binary files a/csharp/testdata/test_types_FLOAT16.pb and /dev/null differ