// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Portions of this code are from System.Half struct dotnet runtime.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// Utilities class created to fill in the gaps of functionality that is absent
    /// in the BitConverter class in NETSTANDARD 2.0, as well as some Single precision
    /// bit constants.
    /// </summary>
    internal class BitOpsUtils
    {
        // Lifted from .NET source code internal code
        // Constants for Single precision format
        // https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Single.cs,dda909df0f8d2fd0

        internal const uint SingleBiasedExponentMask = 0x7F80_0000;
        internal const int SingleBiasedExponentShift = 23;

        internal const uint SingleSignMask = 0x8000_0000;
        internal const int SingleSignShift = 31;

        // Most significant significand bit
        internal const uint SingleMostSignificantSigBit = 0x400000;
        internal const uint SingleTrailingSignificandMask = 0x007F_FFFF;

        /// <summary>
        /// Overlays a float and a uint on the same four bytes so that the bit pattern of
        /// one can be read as the other without unsafe code.
        /// </summary>
        [StructLayout(LayoutKind.Explicit)]
        private struct SingleUInt32Union
        {
            [FieldOffset(0)]
            public float Single;

            [FieldOffset(0)]
            public uint UInt32;
        }

        /// <summary>
        /// Required because BitOperations are not available in NETSTANDARD 2.0.
        /// There are more efficient ways with bit twiddling, but this one has clarity.
        /// </summary>
        /// <param name="num">value</param>
        /// <returns>number of leading zeros (32 for zero input). Useful to compute log2 as well.</returns>
        internal static int LeadingZeroCount(uint num)
        {
            if (num == 0)
            {
                return 32;
            }

            int count = 0;
            // Skip whole zero nibbles first, then finish bit by bit.
            while ((num & 0xF000_0000) == 0)
            {
                count += 4;
                num <<= 4;
            }

            while ((num & 0x8000_0000) == 0)
            {
                count += 1;
                num <<= 1;
            }

            return count;
        }

        /// <summary>
        /// Extracts single precision number bit representation as uint
        /// so its bits can be manipulated.
        ///
        /// This API is the reverse of UInt32BitsToSingle().
        /// </summary>
        /// <param name="single">float value</param>
        /// <returns>the 32 bits of <paramref name="single"/> reinterpreted as uint</returns>
        internal static uint SingleToUInt32Bits(float single)
        {
            SingleUInt32Union bits = default;
            bits.Single = single;
            return bits.UInt32;
        }

        /// <summary>
        /// Needed because BitConverter impl is not available until
        /// later versions. This API is the reverse of SingleToUInt32Bits().
        ///
        /// For the exact bit representation of float see IEEE 754 standard for single precision.
        /// </summary>
        /// <param name="singleBits">bit representation of float either obtained from
        /// SingleToUInt32Bits or assembled using bitwise operators</param>
        /// <returns>the float that has exactly these bits</returns>
        internal static float UInt32BitsToSingle(uint singleBits)
        {
            SingleUInt32Union bits = default;
            bits.UInt32 = singleBits;
            return bits.Single;
        }

        /// <summary>
        /// Truncates single precision bits, which can be obtained using SingleToUInt32Bits()
        /// or manually constructed according to IEEE 754 standard, to bfloat16 bits.
        /// </summary>
        /// <param name="singleBits">bits representation of a single precision number (float)</param>
        /// <returns>the upper 16 bits (sign, exponent, top 7 significand bits)</returns>
        internal static ushort SingleBitsToBFloat16Bits(uint singleBits)
        {
            // Shifts operate on the integer value, not on its memory layout, so the
            // upper half is selected the same way regardless of host endianness.
            return (ushort)(singleBits >> 16);
        }

        /// <summary>
        /// Converts bfloat16 ushort bits representation to single precision bits which then in turn can be
        /// manipulated or converted to float using UInt32BitsToSingle()
        /// </summary>
        /// <param name="bfloatBits">ushort bits representation of bfloat16</param>
        /// <returns>single precision bits with the bfloat16 bits in the upper half</returns>
        internal static uint BFloat16BitsToSingleBits(ushort bfloatBits)
        {
            // Value-level shift; independent of host endianness.
            return (uint)bfloatBits << 16;
        }

        /// <summary>
        /// Creates float NaN with the given sign and a significand that has been
        /// left-aligned in a 64-bit integer.
        /// </summary>
        /// <param name="sign">true for negative</param>
        /// <param name="significand">source significand shifted so that its most significant
        /// bit is bit 63; the top 23 bits become the float trailing significand</param>
        /// <returns>a quiet float NaN</returns>
        internal static float CreateSingleNaN(bool sign, ulong significand)
        {
            // We need to set at least one bit in the NaN significand
            const uint NaNBits = SingleBiasedExponentMask | SingleMostSignificantSigBit;

            uint signInt = (sign ? 1U : 0U) << SingleSignShift;
            // 64 - 23 = 41: keep the top 23 bits
            uint sigInt = (uint)(significand >> 41);
            uint singleBits = signInt | NaNBits | sigInt;

            return UInt32BitsToSingle(singleBits);
        }

        /// <summary>
        /// Creates float from sign, exponent and significand
        /// </summary>
        /// <param name="sign">true if negative</param>
        /// <param name="exponent">biased exponent</param>
        /// <param name="significand">significand; bits above the 23 trailing bits carry into the exponent</param>
        /// <returns>the assembled float</returns>
        internal static float CreateSingle(bool sign, byte exponent, uint significand)
        {
            uint signInt = (sign ? 1U : 0U) << SingleSignShift;
            // Addition (not OR) is deliberate: a significand with bit 23 set bumps the exponent.
            uint expInt = ((uint)exponent << SingleBiasedExponentShift) + significand;
            uint singleBits = signInt + expInt;

            return UInt32BitsToSingle(singleBits);
        }
    }

    /// <summary>
    /// This value type represents A Float16 value
    /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
    /// and as such, represented the same way in managed and native memories. This means that arrays of this type
    /// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
    /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
    /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
    ///
    /// The implementation is derived from
    /// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public readonly struct Float16 :
        IComparable,
        IComparable<Float16>,
        IEquatable<Float16>
    {
        internal const ushort SignMask = 0x8000;
        internal const int SignShift = 15;
        internal const byte ShiftedSignMask = SignMask >> SignShift;

        internal const ushort BiasedExponentMask = 0x7C00; // 0b0_111_1100_0000_0000;
        internal const int BiasedExponentShift = 10;
        internal const byte ShiftedBiasedExponentMask = BiasedExponentMask >> BiasedExponentShift;

        internal const ushort TrailingSignificandMask = 0x03FF; // 0b0_000_0011_1111_1111;

        internal const byte MinSign = 0;
        internal const byte MaxSign = 1;

        internal const byte MinBiasedExponent = 0x00;
        internal const byte MaxBiasedExponent = 0x1F;

        internal const byte ExponentBias = 15;

        internal const sbyte MinExponent = -14;
        internal const sbyte MaxExponent = +15;

        // Constants representing the private bit-representation for various default values

        private const ushort PositiveZeroBits = 0x0000;
        private const ushort NegativeZeroBits = 0x8000;

        private const ushort OneBits = 0x3C00;

        private const ushort EpsilonBits = 0x0400;

        private const ushort PositiveInfinityBits = 0x7C00;
        private const ushort NegativeInfinityBits = 0xFC00;

        private const ushort PositiveQNaNBits = 0x7E00;
        private const ushort NegativeQNaNBits = 0xFE00;

        private const ushort MinValueBits = 0xFBFF;
        private const ushort MaxValueBits = 0x7BFF;

        private const ushort PositiveOneBits = 0x3C00;
        private const ushort NegativeOneBits = 0xBC00;

        private const ushort EBits = 0x4170;
        private const ushort PiBits = 0x4248;
        private const ushort TauBits = 0x4648;

        // Well-defined and commonly used values

        /// <summary>
        /// Float16 Epsilon value. Bits 0x0400: biased exponent 1, zero significand,
        /// i.e. 2^-14 (about 6.1035E-05).
        /// </summary>
        public static Float16 Epsilon => new Float16(EpsilonBits);

        /// <summary>
        /// Float16 Pi value (bits 0x4248, i.e. 3.140625)
        /// </summary>
        public static Float16 Pi => new Float16(PiBits);

        /// <summary>
        /// Float16 Positive Infinity value
        /// </summary>
        public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); // 1.0 / 0.0;

        /// <summary>
        /// Float16 Negative Infinity value
        /// </summary>
        public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); // -1.0 / 0.0

        /// <summary>
        /// Float16 NaN (the negative quiet NaN bit pattern 0xFE00)
        /// </summary>
        public static Float16 NaN => new Float16(NegativeQNaNBits); // 0.0 / 0.0

        /// <summary>
        /// Float16 Zero value
        /// </summary>
        public static Float16 Zero => new Float16(PositiveZeroBits); // 0.0

        /// <summary>
        /// Float16 One value
        /// </summary>
        public static Float16 One => new Float16(OneBits); // 1.0

        /// <summary>
        /// Float16 Negative Zero value
        /// </summary>
        public static Float16 NegativeZero => new Float16(NegativeZeroBits); // -0.0

        /// <summary>
        /// Float16 Min value
        /// </summary>
        public static Float16 MinValue => new Float16(MinValueBits); // bits 0xFBFF (64,511 as ushort)

        /// <summary>
        /// Float16 Max value
        /// </summary>
        public static Float16 MaxValue => new Float16(MaxValueBits); // bits 0x7BFF (31,743 as ushort)

        /// <summary>
        /// float16 representation bits
        /// </summary>
        public readonly ushort value;

        /// <summary>
        /// Ctor from ushort bits, no conversion is done
        /// </summary>
        /// <param name="v">float16 representation bits</param>
        public Float16(ushort v)
        {
            value = v;
        }

        // Packs sign, biased exponent and significand. Addition lets a significand
        // overflow carry into the exponent field.
        private Float16(bool sign, ushort exp, ushort sig) =>
            value = (ushort)(((sign ? 1 : 0) << SignShift) + (exp << BiasedExponentShift) + sig);

        internal byte BiasedExponent
        {
            get
            {
                ushort bits = value;
                return ExtractBiasedExponentFromBits(bits);
            }
        }

        internal sbyte Exponent
        {
            get
            {
                return (sbyte)(BiasedExponent - ExponentBias);
            }
        }

        // Trailing significand plus the implicit leading bit for non-zero biased exponents.
        internal ushort Significand
        {
            get
            {
                return (ushort)(TrailingSignificand | ((BiasedExponent != 0) ? (1U << BiasedExponentShift) : 0U));
            }
        }

        internal ushort TrailingSignificand
        {
            get
            {
                ushort bits = value;
                return ExtractTrailingSignificandFromBits(bits);
            }
        }

        internal static byte ExtractBiasedExponentFromBits(ushort bits)
        {
            return (byte)((bits >> BiasedExponentShift) & ShiftedBiasedExponentMask);
        }

        internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
        {
            return (ushort)(bits & TrailingSignificandMask);
        }

        /// <summary>
        /// Compares values of two Float16
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>returns true if left is less than right according to IEEE</returns>
        public static bool operator <(Float16 left, Float16 right)
        {
            if (IsNaN(left) || IsNaN(right))
            {
                // IEEE defines that NaN is unordered with respect to everything, including itself.
                return false;
            }

            bool leftIsNegative = IsNegative(left);

            if (leftIsNegative != IsNegative(right))
            {
                // When the signs of left and right differ, we know that left is less than right if it is
                // the negative value. The exception to this is if both values are zero, in which case IEEE
                // says they should be equal, even if the signs differ.
                return leftIsNegative && !AreZero(left, right);
            }

            // Same sign: for negatives a larger magnitude (larger bits) means a smaller value.
            return (left.value != right.value) && ((left.value < right.value) ^ leftIsNegative);
        }

        /// <summary>
        /// Compares values of two Float16
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>returns true if left is greater than right according to IEEE</returns>
        public static bool operator >(Float16 left, Float16 right)
        {
            return right < left;
        }

        /// <summary>
        /// Compares values of two Float16
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>returns true if left is less or equal than right according to IEEE</returns>
        public static bool operator <=(Float16 left, Float16 right)
        {
            if (IsNaN(left) || IsNaN(right))
            {
                // IEEE defines that NaN is unordered with respect to everything, including itself.
                return false;
            }

            bool leftIsNegative = IsNegative(left);

            if (leftIsNegative != IsNegative(right))
            {
                // When the signs of left and right differ, we know that left is less than right if it is
                // the negative value. The exception to this is if both values are zero, in which case IEEE
                // says they should be equal, even if the signs differ.
                return leftIsNegative || AreZero(left, right);
            }

            return (left.value == right.value) || ((left.value < right.value) ^ leftIsNegative);
        }

        /// <summary>
        /// Compares values of two Float16
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>returns true if left is greater or equal than right according to IEEE</returns>
        public static bool operator >=(Float16 left, Float16 right)
        {
            return right <= left;
        }

        /// <summary>
        /// Compares values of two Float16 for equality.
        /// If either of the values is NaN, this will return false.
        /// Positive and negative zero compare equal.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>true if values are equal according to IEEE</returns>
        public static bool operator ==(Float16 left, Float16 right)
        {
            if (IsNaN(left) || IsNaN(right))
            {
                // IEEE defines that NaN is not equal to anything, including itself.
                return false;
            }

            // +0 and -0 have different bits but are equal; this keeps == consistent
            // with <=, >=, Equals and CompareTo.
            return left.value == right.value || AreZero(left, right);
        }

        /// <summary>
        /// Compares values of two Float16 for inequality.
        /// If either of the values is NaN, this will return true.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>true if values are not equal according to IEEE</returns>
        public static bool operator !=(Float16 left, Float16 right)
        {
            return !(left == right);
        }

        /// <summary>
        /// Determines whether the specified value is finite (zero, subnormal, or normal).
        /// </summary>
        /// <param name="value">Float16 instance.</param>
        /// <returns>true if the value is finite</returns>
        public static bool IsFinite(Float16 value)
        {
            return StripSign(value) < PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is infinite.
        /// </summary>
        /// <param name="value">Float16 instance.</param>
        /// <returns>true if the value is infinite</returns>
        public static bool IsInfinity(Float16 value)
        {
            return StripSign(value) == PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is NaN.
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>true if the value is not a number</returns>
        public static bool IsNaN(Float16 value)
        {
            return StripSign(value) > PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is negative.
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>true if the sign bit is set</returns>
        public static bool IsNegative(Float16 value)
        {
            return (short)(value.value) < 0;
        }

        /// <summary>
        /// Determines whether the specified value is negative infinity.
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>true if the value is negative infinity</returns>
        public static bool IsNegativeInfinity(Float16 value)
        {
            return value.value == NegativeInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is normal
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>true or false</returns>
        public static bool IsNormal(Float16 value)
        {
            uint absValue = StripSign(value);
            return (absValue < PositiveInfinityBits)    // is finite
                && (absValue != 0)                      // is not zero
                && ((absValue & BiasedExponentMask) != 0);    // is not subnormal (has a non-zero exponent)
        }

        /// <summary>
        /// Determines whether the specified value is positive infinity.
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>true if the value is positive infinity</returns>
        public static bool IsPositiveInfinity(Float16 value)
        {
            return value.value == PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is subnormal.
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>true if the value is subnormal</returns>
        public static bool IsSubnormal(Float16 value)
        {
            uint absValue = StripSign(value);
            return (absValue < PositiveInfinityBits)    // is finite
                && (absValue != 0)                      // is not zero
                && ((absValue & BiasedExponentMask) == 0);    // is subnormal (has a zero exponent)
        }

        /// <summary>
        /// Compares this object to another object, returning an integer that indicates the relationship.
        /// </summary>
        /// <param name="obj">Object to compare to</param>
        /// <returns>A value less than zero if this is less than <paramref name="obj"/>,
        /// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
        /// if this is greater than <paramref name="obj"/> (null sorts first).</returns>
        /// <exception cref="System.ArgumentException">Thrown when <paramref name="obj"/> is not of type <see cref="Float16"/>.</exception>
        public int CompareTo(object obj)
        {
            if (!(obj is Float16))
            {
                return (obj is null) ? 1 : throw new ArgumentException("Object must be of type Float16");
            }
            return CompareTo((Float16)(obj));
        }

        /// <summary>
        /// Compares this object to another object, returning an integer that indicates the relationship.
        /// NaN sorts before every non-NaN value and equal to another NaN.
        /// </summary>
        /// <param name="other">Object to compare to</param>
        /// <returns>A value less than zero if this is less than <paramref name="other"/>,
        /// zero if this is equal to <paramref name="other"/>,
        /// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
        public int CompareTo(Float16 other)
        {
            if (this < other)
            {
                return -1;
            }

            if (this > other)
            {
                return 1;
            }

            if (this == other)
            {
                return 0;
            }

            // Only unordered (NaN) operands reach this point.
            if (IsNaN(this))
            {
                return IsNaN(other) ? 0 : -1;
            }

            Debug.Assert(IsNaN(other));
            return 1;
        }

        /// <summary>
        /// Returns a value indicating whether this instance and other Float16 represent the same value.
        /// Unlike operator ==, two NaN values are considered equal here.
        /// </summary>
        /// <param name="other">A Float16 object to compare to this instance.</param>
        /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
        public bool Equals(Float16 other)
        {
            return value == other.value
                || AreZero(this, other)
                || (IsNaN(this) && IsNaN(other));
        }

        /// <summary>
        /// Returns a value indicating whether this instance and a specified System.Object
        /// represent the same type and value.
        /// </summary>
        /// <param name="obj">An System.Object.</param>
        /// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
        public override bool Equals(object obj)
        {
            return (obj is Float16 other) && Equals(other);
        }

        /// <summary>
        /// Returns the hash code for this instance.
        /// </summary>
        /// <returns>A 32-bit signed integer hash code.</returns>
        public override int GetHashCode()
        {
            if (IsNaNOrZero(this))
            {
                // All NaNs should have the same hash code, as should both Zeros.
                return value & PositiveInfinityBits;
            }
            return value;
        }

        /// <summary>
        /// Returns a string representation of the current value.
        /// </summary>
        /// <returns>Text representation of Float16: the raw bits followed by the float value</returns>
        public override string ToString()
        {
            return $"{value} : {ToFloat()}";
        }

        /// <summary>
        /// Explicit conversion
        /// </summary>
        /// <returns>single precision value converted from Float16</returns>
        public float ToFloat()
        {
            return (float)this;
        }

        /// <summary>Explicitly converts a <see cref="float" /> value to its nearest representable half-precision floating-point value.</summary>
        /// <param name="value">The value to convert.</param>
        /// <returns><paramref name="value" /> converted to its nearest representable half-precision floating-point value.</returns>
        public static explicit operator Float16(float value)
        {
            const int SingleMaxExponent = 0xFF;

            uint floatInt = BitOpsUtils.SingleToUInt32Bits(value);
            bool sign = (floatInt & BitOpsUtils.SingleSignMask) >> BitOpsUtils.SingleSignShift != 0;
            int exp = (int)(floatInt & BitOpsUtils.SingleBiasedExponentMask) >> BitOpsUtils.SingleBiasedExponentShift;
            uint sig = floatInt & BitOpsUtils.SingleTrailingSignificandMask;

            if (exp == SingleMaxExponent)
            {
                if (sig != 0) // NaN
                {
                    return CreateFloat16NaN(sign, (ulong)sig << 41); // Shift the significand bits to the left end
                }
                return sign ? NegativeInfinity : PositiveInfinity;
            }

            uint sigHalf = sig >> 9 | ((sig & 0x1FFU) != 0 ? 1U : 0U); // RightShiftJam

            if ((exp | (int)sigHalf) == 0)
            {
                return new Float16(sign, 0, 0);
            }

            return new Float16(RoundPackToFloat16(sign, (short)(exp - 0x71), (ushort)(sigHalf | 0x4000)));
        }

        /// <summary>Explicitly converts a half-precision floating-point value to its nearest representable <see cref="float" /> value.</summary>
        /// <param name="value">The value to convert.</param>
        /// <returns><paramref name="value" /> converted to its nearest representable <see cref="float" /> value.</returns>
        public static explicit operator float(Float16 value)
        {
            bool sign = IsNegative(value);
            int exp = value.BiasedExponent;
            uint sig = value.TrailingSignificand;

            if (exp == MaxBiasedExponent)
            {
                if (sig != 0)
                {
                    // Left-align the 10 trailing significand bits (64 - 10 = 54)
                    return BitOpsUtils.CreateSingleNaN(sign, (ulong)sig << 54);
                }
                return sign ? float.NegativeInfinity : float.PositiveInfinity;
            }

            if (exp == 0)
            {
                if (sig == 0)
                {
                    // Positive / Negative zero
                    return (sign) ? -0.0f : 0.0f;
                }
                // Subnormal half: normalize so that it becomes a normal float.
                (exp, sig) = NormSubnormalF16Sig(sig);
                exp -= 1;
            }

            // 0x70 = 127 - 15 rebiases the exponent; 13 = 23 - 10 widens the significand.
            return BitOpsUtils.CreateSingle(sign, (byte)(exp + 0x70), sig << 13);
        }

        /// <summary>
        /// Flips the sign. NaNs are not affected.
        /// IEEE 754 specifies NaNs to be propagated
        /// </summary>
        /// <param name="value">Float16 instance</param>
        /// <returns>the value with the sign bit flipped, or the same NaN</returns>
        public static Float16 Negate(Float16 value)
        {
            return IsNaN(value) ? value : new Float16((ushort)(value.value ^ SignMask));
        }

        #region Utilities

        private static bool AreZero(Float16 left, Float16 right)
        {
            // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
            // for two values by or'ing the private bits together and stripping the sign. They are both zero,
            // and therefore equivalent, if the resulting value is still zero.
            return (ushort)((left.value | right.value) & ~SignMask) == 0;
        }

        /// <summary>
        /// The function returns true if the value is either NaN or zero.
        /// </summary>
        /// <param name="value">instance of Float16</param>
        /// <returns>true if NaN or zero.</returns>
        public static bool IsNaNOrZero(Float16 value)
        {
            uint abs = StripSign(value);
            return (abs == 0 || abs > PositiveInfinityBits);
        }

        private static uint StripSign(Float16 value)
        {
            return (ushort)(value.value & ~SignMask);
        }

        // Shifts a non-zero subnormal significand left until bit 10 is set and
        // returns the matching exponent adjustment.
        private static (int Exp, uint Sig) NormSubnormalF16Sig(uint sig)
        {
            int shiftDist = BitOpsUtils.LeadingZeroCount(sig) - 16 - 5;
            return (1 - shiftDist, sig << shiftDist);
        }

        // Significand bits should be shifted towards to the left end before calling these methods
        // Creates Quiet NaN if significand == 0
        private static Float16 CreateFloat16NaN(bool sign, ulong significand)
        {
            const ushort NaNBits = BiasedExponentMask | 0x200; // Most significant significand bit

            uint signInt = (sign ? 1U : 0U) << SignShift;
            ushort sigInt = (ushort)(significand >> 54);

            ushort ushortBits = (ushort)(signInt | NaNBits | sigInt);
            return new Float16(ushortBits);
        }

        private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
        {
            const int RoundIncrement = 0x8; // Depends on rounding mode but it's always towards closest / ties to even
            int roundBits = sig & 0xF;

            if ((uint)exp >= 0x1D)
            {
                if (exp < 0)
                {
                    sig = (ushort)ShiftRightJam(sig, -exp);
                    exp = 0;
                    roundBits = sig & 0xF;
                }
                else if (exp > 0x1D || sig + RoundIncrement >= 0x8000) // Overflow
                {
                    return sign ? NegativeInfinityBits : PositiveInfinityBits;
                }
            }

            sig = (ushort)((sig + RoundIncrement) >> 4);
            // On an exact tie (roundBits == 8) clear the lowest bit: ties to even.
            sig &= (ushort)~(((roundBits ^ 8) != 0 ? 0 : 1) & 1);

            if (sig == 0)
            {
                exp = 0;
            }

            return new Float16(sign, (ushort)exp, sig).value;
        }

        // If any bits are lost by shifting, "jam" them into the LSB.
        // if dist > bit count, Will be 1 or 0 depending on i
        // (unlike bitwise operators that masks the lower 5 bits)
        private static uint ShiftRightJam(uint i, int dist) => dist < 31 ? (i >> dist) | (i << (-dist & 31) != 0 ? 1U : 0U) : (i != 0 ? 1U : 0U);

        private static ulong ShiftRightJam(ulong l, int dist) => dist < 63 ? (l >> dist) | (l << (-dist & 63) != 0 ? 1UL : 0UL) : (l != 0 ? 1UL : 0UL);

        #endregion
    }

    /// <summary>
    /// This value type represents A BFloat16 value.
    /// See https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
    /// for details.
    /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
    /// and as such, represented the same way in managed and native memories. This means that arrays of this type
    /// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
    /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
    /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public readonly struct BFloat16 :
        IComparable,
        IComparable<BFloat16>,
        IEquatable<BFloat16>
    {
        internal const ushort SignMask = 0x8000;
        internal const int SignShift = 15;
        internal const byte ShiftedSignMask = SignMask >> SignShift;

        internal const ushort BiasedExponentMask = 0x7F80; // 0b_0111_1111_1000_0000;
        internal const int BiasedExponentShift = 7;
        internal const byte ShiftedBiasedExponentMask = BiasedExponentMask >> BiasedExponentShift;

        internal const ushort TrailingSignificandMask = 0x007F; // 0b_0000_0000_0111_1111;

        internal const byte MinSign = 0;
        internal const byte MaxSign = 1;

        internal const byte MinBiasedExponent = 0x00;
        internal const byte MaxBiasedExponent = 0xFF;

        internal const byte ExponentBias = 127;

        internal const sbyte MinExponent = -126;
        internal const sbyte MaxExponent = +127;

        // Constants representing the private bit-representation for various default values

        private const ushort PositiveZeroBits = 0x0000;
        private const ushort NegativeZeroBits = 0x8000;

        private const ushort OneBits = 0x3F80; // 0b0_01111111_0000000

        private const ushort PositiveInfinityBits = 0x7F80;
        private const ushort NegativeInfinityBits = 0xFF80;

        private const ushort PositiveQNaNBits = 0x7FC1;
        private const ushort NegativeQNaNBits = 0xFFC1;

        private const ushort MinValueBits = 0xFF7F; // 1b0_11111110_1111111
        private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111

        private const ushort EpsilonBits = 0x0080; // the smallest positive normal value

        private const ushort PiBits = 0x4049; // 0b0_10000000_1001001

        // Used for rounding subnormal values
        private const uint RoundingBase = 0x7FFF;

        // Well-defined and commonly used values

        /// <summary>
        /// BFloat16 Epsilon value (the smallest positive normal value)
        /// </summary>
        public static BFloat16 Epsilon => new BFloat16(EpsilonBits);

        /// <summary>
        /// BFloat16 Pi value
        /// </summary>
        public static BFloat16 Pi => new BFloat16(PiBits);

        /// <summary>
        /// BFloat16 Positive infinity value
        /// </summary>
        public static BFloat16 PositiveInfinity => new BFloat16(PositiveInfinityBits);

        /// <summary>
        /// BFloat16 Negative infinity value
        /// </summary>
        public static BFloat16 NegativeInfinity => new BFloat16(NegativeInfinityBits);

        /// <summary>
        /// BFloat16 NaN (the negative quiet NaN bit pattern 0xFFC1)
        /// </summary>
        public static BFloat16 NaN => new BFloat16(NegativeQNaNBits);

        /// <summary>
        /// BFloat16 Positive Zero
        /// </summary>
        public static BFloat16 Zero => new BFloat16(PositiveZeroBits); // 0.0

        /// <summary>
        /// BFloat16 One
        /// </summary>
        public static BFloat16 One => new BFloat16(OneBits); // 1.0

        /// <summary>
        /// BFloat16 Negative Zero
        /// </summary>
        public static BFloat16 NegativeZero => new BFloat16(NegativeZeroBits); // -0.0

        /// <summary>
        /// BFloat16 Min value
        /// </summary>
        public static BFloat16 MinValue => new BFloat16(MinValueBits); // bits 0xFF7F (65,407 as ushort)

        /// <summary>
        /// BFloat16 Max value
        /// </summary>
        public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // bits 0x7F7F (32,639 as ushort)

        /// <summary>
        /// bfloat16 representation bits
        /// </summary>
        public readonly ushort value;

        /// <summary>
        /// Constructor from ushort, no conversion takes place. The value
        /// is assumed to be converted
        /// </summary>
        /// <param name="v">bfloat16 representation bits</param>
        public BFloat16(ushort v)
        {
            value = v;
        }

        // Extracts biased exponent bits
        internal byte BiasedExponent
        {
            get
            {
                ushort bits = value;
                return ExtractBiasedExponentFromBits(bits);
            }
        }

        // Extracts all the trailing significand bits
        internal ushort TrailingSignificand
        {
            get
            {
                ushort bits = value;
                return ExtractTrailingSignificandFromBits(bits);
            }
        }

        internal static byte ExtractBiasedExponentFromBits(ushort bits)
        {
            return (byte)((bits >> BiasedExponentShift) & ShiftedBiasedExponentMask);
        }

        internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
        {
            return (ushort)(bits & TrailingSignificandMask);
        }

        /// <summary>
        /// Compares two BFloat16 instances.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>true if the left is less than right according to IEEE</returns>
        public static bool operator <(BFloat16 left, BFloat16 right)
        {
            if (IsNaN(left) || IsNaN(right))
            {
                // IEEE defines that NaN is unordered with respect to everything, including itself.
                return false;
            }

            bool leftIsNegative = IsNegative(left);

            if (leftIsNegative != IsNegative(right))
            {
                // When the signs of left and right differ, we know that left is less than right if it is
                // the negative value. The exception to this is if both values are zero, in which case IEEE
                // says they should be equal, even if the signs differ.
                return leftIsNegative && !AreZero(left, right);
            }

            // Same sign: for negatives a larger magnitude (larger bits) means a smaller value.
            return (left.value != right.value) && ((left.value < right.value) ^ leftIsNegative);
        }

        /// <summary>
        /// Compares two BFloat16 instances.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>true if the left is greater than right according to IEEE</returns>
        public static bool operator >(BFloat16 left, BFloat16 right)
        {
            return right < left;
        }

        /// <summary>
        /// Compares two BFloat16 instances.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>true if the left is less or equal than right according to IEEE</returns>
        public static bool operator <=(BFloat16 left, BFloat16 right)
        {
            if (IsNaN(left) || IsNaN(right))
            {
                // IEEE defines that NaN is unordered with respect to everything, including itself.
                return false;
            }

            bool leftIsNegative = IsNegative(left);

            if (leftIsNegative != IsNegative(right))
            {
                // When the signs of left and right differ, we know that left is less than right if it is
                // the negative value. The exception to this is if both values are zero, in which case IEEE
                // says they should be equal, even if the signs differ.
                return leftIsNegative || AreZero(left, right);
            }

            return (left.value == right.value) || ((left.value < right.value) ^ leftIsNegative);
        }

        /// <summary>
        /// Compares two BFloat16 instances.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>true if the left is greater or equal than right according to IEEE</returns>
        public static bool operator >=(BFloat16 left, BFloat16 right)
        {
            return right <= left;
        }

        /// <summary>
        /// Compares values of two BFloat16 for equality.
        /// If either of the values is NaN, this will return false.
        /// Positive and negative zero compare equal.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>result of value comparisons</returns>
        public static bool operator ==(BFloat16 left, BFloat16 right)
        {
            if (IsNaN(left) || IsNaN(right))
            {
                // IEEE defines that NaN is not equal to anything, including itself.
                return false;
            }

            // +0 and -0 have different bits but are equal; this keeps == consistent
            // with <=, >=, Equals and CompareTo.
            return left.value == right.value || AreZero(left, right);
        }

        /// <summary>
        /// Compares values of two BFloat16 for inequality
        /// If either of the values is NaN it would return true.
        /// </summary>
        /// <param name="left">left hand side</param>
        /// <param name="right">right hand side</param>
        /// <returns>result of value comparisons</returns>
        public static bool operator !=(BFloat16 left, BFloat16 right)
        {
            return !(left == right);
        }

        /// <summary>
        /// Determines whether the specified value is finite (zero, subnormal, or normal).
        /// </summary>
        /// <param name="value">BFloat16 instance.</param>
        /// <returns>true if the value is finite</returns>
        public static bool IsFinite(BFloat16 value)
        {
            return StripSign(value) < PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is infinite.
        /// </summary>
        /// <param name="value">BFloat16 instance.</param>
        /// <returns>true if the value is infinite</returns>
        public static bool IsInfinity(BFloat16 value)
        {
            return StripSign(value) == PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is NaN.
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>true if the value is not a number</returns>
        public static bool IsNaN(BFloat16 value)
        {
            return StripSign(value) > PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is negative.
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>true if the sign bit is set</returns>
        public static bool IsNegative(BFloat16 value)
        {
            return (short)(value.value) < 0;
        }

        /// <summary>
        /// Determines whether the specified value is negative infinity.
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>true if the value is negative infinity</returns>
        public static bool IsNegativeInfinity(BFloat16 value)
        {
            return value.value == NegativeInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is normal
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>true or false</returns>
        public static bool IsNormal(BFloat16 value)
        {
            uint absValue = StripSign(value);
            return (absValue < PositiveInfinityBits)    // is finite
                && (absValue != 0)                      // is not zero
                && ((absValue & BiasedExponentMask) != 0);    // is not subnormal (has a non-zero exponent)
        }

        /// <summary>
        /// Determines whether the specified value is positive infinity.
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>true if the value is positive infinity</returns>
        public static bool IsPositiveInfinity(BFloat16 value)
        {
            return value.value == PositiveInfinityBits;
        }

        /// <summary>
        /// Determines whether the specified value is subnormal.
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>true if the value is subnormal</returns>
        public static bool IsSubnormal(BFloat16 value)
        {
            uint absValue = StripSign(value);
            return (absValue < PositiveInfinityBits)    // is finite
                && (absValue != 0)                      // is not zero
                && ((absValue & BiasedExponentMask) == 0);    // is subnormal (has a zero exponent)
        }

        /// <summary>
        /// Compares this object to another object, returning an integer that indicates the relationship.
        /// </summary>
        /// <param name="obj">Object to compare to</param>
        /// <returns>A value less than zero if this is less than <paramref name="obj"/>,
        /// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
        /// if this is greater than <paramref name="obj"/> (null sorts first).</returns>
        /// <exception cref="System.ArgumentException">Thrown when <paramref name="obj"/> is not of type <see cref="BFloat16"/>.</exception>
        public int CompareTo(object obj)
        {
            if (!(obj is BFloat16))
            {
                return (obj is null) ? 1 : throw new ArgumentException("Object must be of type BFloat16");
            }
            return CompareTo((BFloat16)(obj));
        }

        /// <summary>
        /// Compares this object to another object, returning an integer that indicates the relationship.
        /// NaN sorts before every non-NaN value and equal to another NaN.
        /// </summary>
        /// <param name="other">Object to compare to</param>
        /// <returns>A value less than zero if this is less than <paramref name="other"/>,
        /// zero if this is equal to <paramref name="other"/>,
        /// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
        public int CompareTo(BFloat16 other)
        {
            if (this < other)
            {
                return -1;
            }

            if (this > other)
            {
                return 1;
            }

            if (this == other)
            {
                return 0;
            }

            // Only unordered (NaN) operands reach this point.
            if (IsNaN(this))
            {
                return IsNaN(other) ? 0 : -1;
            }

            Debug.Assert(IsNaN(other));
            return 1;
        }

        /// <summary>
        /// Returns a value indicating whether this instance and other BFloat16 represent the same value.
        /// Unlike operator ==, two NaN values are considered equal here.
        /// </summary>
        /// <param name="other">A BFloat16 object to compare to this instance.</param>
        /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
        public bool Equals(BFloat16 other)
        {
            return value == other.value
                || AreZero(this, other)
                || (IsNaN(this) && IsNaN(other));
        }

        /// <summary>
        /// Returns a value indicating whether this instance and a specified System.Object
        /// represent the same type and value.
        /// </summary>
        /// <param name="obj">An System.Object.</param>
        /// <returns>true if obj is BFloat16 and its value is equal to this instance; otherwise, false.</returns>
        public override bool Equals(object obj)
        {
            return (obj is BFloat16 other) && Equals(other);
        }

        /// <summary>
        /// Returns the hash code for this instance.
        /// </summary>
        /// <returns>A 32-bit signed integer hash code.</returns>
        public override int GetHashCode()
        {
            if (IsNaNOrZero(this))
            {
                // All NaNs should have the same hash code, as should both Zeros.
                return value & PositiveInfinityBits;
            }
            return value;
        }

        /// <summary>
        /// Returns a string representation of the current value.
        /// </summary>
        /// <returns>Text representation of BFloat16: the raw bits followed by the float value</returns>
        public override string ToString()
        {
            return $"{value} : {ToFloat()}";
        }

        /// <summary>
        /// Explicit conversion
        /// </summary>
        /// <returns>single precision value converted from BFloat16</returns>
        public float ToFloat()
        {
            return (float)this;
        }

        /// <summary>Explicitly converts a <see cref="float" /> value to its nearest representable bfloat16 value.</summary>
        /// <param name="value">The value to convert.</param>
        /// <returns><paramref name="value" /> converted to its nearest representable bfloat16 value;
        /// any NaN input yields <see cref="NaN"/>.</returns>
        public static explicit operator BFloat16(float value)
        {
            if (float.IsNaN(value))
            {
                return NaN;
            }

            uint singleBits = BitOpsUtils.SingleToUInt32Bits(value);
            ushort bfloatBits = BitOpsUtils.SingleBitsToBFloat16Bits(singleBits);

            // Round this up. Implement the same logic pytorch uses for rounding.
            // We use RoundingBase that is 0x7FFF + (1), so we carry the 1 to the next bit.
            // either the last bfloat bit is 1 or singleBits have some bits set.
            singleBits += ((uint)bfloatBits & 1) + RoundingBase;
            bfloatBits = BitOpsUtils.SingleBitsToBFloat16Bits(singleBits);
            return new BFloat16(bfloatBits);
        }

        /// <summary>
        /// Explicitly converts a BFloat16 value to its nearest representable <see cref="float" /> value.
        /// </summary>
        /// <param name="value">The value to convert.</param>
        /// <returns><paramref name="value" /> converted to its nearest representable <see cref="float" /> value.</returns>
        public static explicit operator float(BFloat16 value)
        {
            bool sign = IsNegative(value);
            int exp = value.BiasedExponent;
            uint sig = value.TrailingSignificand;

            if (exp == MaxBiasedExponent)
            {
                if (sig != 0)
                {
                    // Left-align the 7 trailing significand bits in a 64-bit integer
                    // (64 - 7 = 57) so they land in the top bits of the float significand.
                    return BitOpsUtils.CreateSingleNaN(sign, (ulong)sig << 57);
                }
                return sign ? float.NegativeInfinity : float.PositiveInfinity;
            }

            if (exp == 0 && sig == 0)
            {
                // Positive / Negative zero
                return (sign) ? -0.0f : 0.0f;
            }

            // All subnormal numbers in BFloat16 would be also subnormal in FP32 because they
            // share the exponent.
            uint singleBits = BitOpsUtils.BFloat16BitsToSingleBits(value.value);
            return BitOpsUtils.UInt32BitsToSingle(singleBits);
        }

        /// <summary>
        /// Flips the sign. NaNs are not affected.
        /// IEEE 754 specifies NaNs to be propagated
        /// </summary>
        /// <param name="value">BFloat16 instance</param>
        /// <returns>the value with the sign bit flipped, or the same NaN</returns>
        public static BFloat16 Negate(BFloat16 value)
        {
            return IsNaN(value) ? value : new BFloat16((ushort)(value.value ^ SignMask));
        }

        /// <summary>
        /// The function returns true if the value is either NaN or zero.
        /// </summary>
        /// <param name="value">instance of BFloat16</param>
        /// <returns>true if NaN or zero.</returns>
        public static bool IsNaNOrZero(BFloat16 value)
        {
            uint abs = StripSign(value);
            return (abs == 0 || abs > PositiveInfinityBits);
        }

        #region Utilities

        private static bool AreZero(BFloat16 left, BFloat16 right)
        {
            // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
            // for two values by or'ing the private bits together and stripping the sign. They are both zero,
            // and therefore equivalent, if the resulting value is still zero.
            return (ushort)((left.value | right.value) & ~SignMask) == 0;
        }

        private static uint StripSign(BFloat16 value)
        {
            return (ushort)(value.value & ~SignMask);
        }

        #endregion
    }
}