// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Portions of this code are from System.Half struct dotnet runtime.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
// Utilities class created to fill in the gaps
// of functionality that is absent in the BitConverter class in NETSTANDARD 2.0
// as well as some single precision bit constants.
internal class BitOpsUtils
{
    // Lifted from .NET source code internal code
    // Constants for Single precision format
    // https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Single.cs,dda909df0f8d2fd0
    internal const uint SingleBiasedExponentMask = 0x7F80_0000;
    internal const int SingleBiasedExponentShift = 23;
    internal const uint SingleSignMask = 0x8000_0000;
    internal const int SingleSignShift = 31;
    // Most significant bit of the trailing significand (bit 22); set in quiet NaNs
    internal const uint SingleMostSignificantSigBit = 0x400000;
    internal const uint SingleTrailingSignificandMask = 0x007F_FFFF;

    // Overlays a float and a uint on the same 4 bytes so the bit pattern of one can be
    // read as the other without unsafe code and without allocating
    // (BitConverter.SingleToInt32Bits is not available in NETSTANDARD 2.0).
    [StructLayout(LayoutKind.Explicit)]
    private struct SingleUInt32Union
    {
        [FieldOffset(0)]
        public float Single;

        [FieldOffset(0)]
        public uint UInt32;
    }

    /// <summary>
    /// Counts the number of leading zero bits in a 32-bit unsigned integer.
    /// Required because BitOperations are not available in NETSTANDARD 2.0.
    /// There are more efficient ways with bit twiddling, but this one has clarity.
    /// </summary>
    /// <param name="num">value to inspect</param>
    /// <returns>number of leading zeros (32 for 0). Useful to compute log2 as well.</returns>
    internal static int LeadingZeroCount(uint num)
    {
        if (num == 0)
        {
            return 32;
        }

        int count = 0;
        // Skip whole zero nibbles first, then finish bit by bit.
        while ((num & 0xF000_0000) == 0)
        {
            count += 4;
            num <<= 4;
        }

        while ((num & 0x8000_0000) == 0)
        {
            count += 1;
            num <<= 1;
        }

        return count;
    }

    /// <summary>
    /// Extracts the single precision bit representation as a uint
    /// so its bits can be manipulated.
    /// This API is the reverse of UInt32BitsToSingle().
    /// </summary>
    /// <param name="single">float value</param>
    /// <returns>the IEEE 754 single precision bit pattern of <paramref name="single"/></returns>
    internal static uint SingleToUInt32Bits(float single)
    {
        var union = new SingleUInt32Union { Single = single };
        return union.UInt32;
    }

    /// <summary>
    /// Reinterprets a single precision bit pattern as a float.
    /// Needed because the BitConverter implementation is not available until
    /// later versions. This API is the reverse of SingleToUInt32Bits().
    /// For the exact bit representation of float see the IEEE 754 standard for single precision.
    /// </summary>
    /// <param name="singleBits">bit representation of a float either obtained from
    /// SingleToUInt32Bits or assembled using bitwise operators</param>
    /// <returns>the float whose bit pattern is <paramref name="singleBits"/></returns>
    internal static float UInt32BitsToSingle(uint singleBits)
    {
        var union = new SingleUInt32Union { UInt32 = singleBits };
        return union.Single;
    }

    /// <summary>
    /// Truncates a single precision bit representation (obtained with SingleToUInt32Bits()
    /// or constructed manually according to IEEE 754) to bfloat16 bits by keeping its upper 16 bits.
    /// No rounding is performed.
    /// </summary>
    /// <param name="singleBits">bit representation of a single precision number (float)</param>
    /// <returns>bfloat16 bit representation</returns>
    internal static ushort SingleBitsToBFloat16Bits(uint singleBits)
    {
        // singleBits is a numeric value, not a byte buffer, so shifting selects the upper
        // half on every platform; host endianness only matters when reinterpreting raw memory.
        return (ushort)(singleBits >> 16);
    }

    /// <summary>
    /// Converts the bfloat16 ushort bit representation to single precision bits, which in turn can be
    /// manipulated or converted to float using UInt32BitsToSingle().
    /// </summary>
    /// <param name="bfloatBits">ushort bit representation of a bfloat16</param>
    /// <returns>single precision bits whose upper half is <paramref name="bfloatBits"/> and lower half is zero</returns>
    internal static uint BFloat16BitsToSingleBits(ushort bfloatBits)
    {
        // Numeric shift, independent of host endianness (see SingleBitsToBFloat16Bits).
        return (uint)bfloatBits << 16;
    }

    /// <summary>
    /// Creates a quiet float NaN with the given sign and payload.
    /// </summary>
    /// <param name="sign">true for negative</param>
    /// <param name="significand">NaN payload left-aligned in a 64-bit integer (most significant
    /// payload bit at bit 63). Bits 41..63 become the float's 23-bit trailing significand; lower bits
    /// are discarded. For example, a Float16 trailing significand is passed shifted left by 54.</param>
    /// <returns>a NaN float; the quiet bit is always set</returns>
    internal static float CreateSingleNaN(bool sign, ulong significand)
    {
        // At least one significand bit must be set for the result to be a NaN; the quiet bit guarantees it.
        const uint NaNBits = SingleBiasedExponentMask | SingleMostSignificantSigBit;
        uint signInt = (sign ? 1U : 0U) << SingleSignShift;
        uint sigInt = (uint)(significand >> 41);
        uint singleBits = signInt | NaNBits | sigInt;
        return UInt32BitsToSingle(singleBits);
    }

    /// <summary>
    /// Creates a float from sign, biased exponent and significand.
    /// </summary>
    /// <param name="sign">true if negative</param>
    /// <param name="exponent">biased exponent</param>
    /// <param name="significand">significand bits; added (not OR-ed) to the shifted exponent so that
    /// a set bit 23 (an explicit leading one) carries into the exponent field</param>
    /// <returns>the assembled float</returns>
    internal static float CreateSingle(bool sign, byte exponent, uint significand)
    {
        uint signInt = (sign ? 1U : 0U) << SingleSignShift;
        uint expInt = ((uint)exponent << SingleBiasedExponentShift) + significand;
        uint singleBits = signInt + expInt;
        return UInt32BitsToSingle(singleBits);
    }
}
/// <summary>
/// This value type represents an IEEE 754 half-precision (binary16) floating point value.
/// It is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
/// and as such is represented the same way in managed and native memory. This means that arrays of this type
/// do not have to be copied to be passed to native memory but can simply be pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to the Onnxruntime library.
/// Binary-wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
/// </summary>
/// <remarks>
/// The implementation is derived from
/// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974
/// Conversion from float rounds to nearest, ties to even.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
public readonly struct Float16 :
    IComparable,
    IComparable<Float16>,
    IEquatable<Float16>
{
    internal const ushort SignMask = 0x8000;
    internal const int SignShift = 15;
    internal const byte ShiftedSignMask = SignMask >> SignShift;
    internal const ushort BiasedExponentMask = 0x7C00; // 0b0_111_1100_0000_0000;
    internal const int BiasedExponentShift = 10;
    internal const byte ShiftedBiasedExponentMask = BiasedExponentMask >> BiasedExponentShift;
    internal const ushort TrailingSignificandMask = 0x03FF; // 0b0_000_0011_1111_1111;
    internal const byte MinSign = 0;
    internal const byte MaxSign = 1;
    internal const byte MinBiasedExponent = 0x00;
    internal const byte MaxBiasedExponent = 0x1F;
    internal const byte ExponentBias = 15;
    internal const sbyte MinExponent = -14;
    internal const sbyte MaxExponent = +15;

    // Constants representing the private bit-representation for various default values
    private const ushort PositiveZeroBits = 0x0000;
    private const ushort NegativeZeroBits = 0x8000;
    private const ushort OneBits = 0x3C00;
    private const ushort EpsilonBits = 0x0400;   // smallest positive normal value, 2^-14
    private const ushort PositiveInfinityBits = 0x7C00;
    private const ushort NegativeInfinityBits = 0xFC00;
    private const ushort PositiveQNaNBits = 0x7E00;
    private const ushort NegativeQNaNBits = 0xFE00;
    private const ushort MinValueBits = 0xFBFF;  // -65504
    private const ushort MaxValueBits = 0x7BFF;  // 65504
    private const ushort PositiveOneBits = 0x3C00;
    private const ushort NegativeOneBits = 0xBC00;
    private const ushort EBits = 0x4170;         // 2.71875
    private const ushort PiBits = 0x4248;        // 3.140625
    private const ushort TauBits = 0x4648;       // 6.28125

    // Well-defined and commonly used values

    /// <summary>
    /// Float16 Epsilon value: the smallest positive normal value, 2^-14 (approximately 6.1035156E-05).
    /// </summary>
    /// <remarks>
    /// This differs from System.Half.Epsilon, which is the smallest positive subnormal value.
    /// </remarks>
    public static Float16 Epsilon => new Float16(EpsilonBits);

    /// <summary>
    /// Float16 value closest to Pi (3.140625).
    /// </summary>
    public static Float16 Pi => new Float16(PiBits);

    /// <summary>
    /// Float16 Positive Infinity value (1.0 / 0.0).
    /// </summary>
    public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits);

    /// <summary>
    /// Float16 Negative Infinity value (-1.0 / 0.0).
    /// </summary>
    public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits);

    /// <summary>
    /// Float16 NaN (bit pattern 0xFE00, a negative quiet NaN), the result of 0.0 / 0.0.
    /// </summary>
    public static Float16 NaN => new Float16(NegativeQNaNBits);

    /// <summary>
    /// Float16 positive Zero value (0.0).
    /// </summary>
    public static Float16 Zero => new Float16(PositiveZeroBits);

    /// <summary>
    /// Float16 One value (1.0).
    /// </summary>
    public static Float16 One => new Float16(OneBits);

    /// <summary>
    /// Float16 Negative Zero value (-0.0).
    /// </summary>
    public static Float16 NegativeZero => new Float16(NegativeZeroBits);

    /// <summary>
    /// Float16 Min value: the most negative finite value, -65504 (bits 0xFBFF).
    /// </summary>
    public static Float16 MinValue => new Float16(MinValueBits);

    /// <summary>
    /// Float16 Max value: the largest finite value, 65504 (bits 0x7BFF).
    /// </summary>
    public static Float16 MaxValue => new Float16(MaxValueBits);

    /// <summary>
    /// float16 representation bits
    /// </summary>
    public readonly ushort value;

    /// <summary>
    /// Ctor from ushort bits, no conversion is done.
    /// </summary>
    /// <param name="v">float16 bit representation</param>
    public Float16(ushort v)
    {
        value = v;
    }

    // Assembles bits from parts. Uses '+' rather than '|' on purpose: a significand that
    // carries into bit 10 (after rounding) increments the exponent.
    private Float16(bool sign, ushort exp, ushort sig) =>
        value = (ushort)(((sign ? 1 : 0) << SignShift) + (exp << BiasedExponentShift) + sig);

    // Biased exponent field (0..0x1F)
    internal byte BiasedExponent
    {
        get
        {
            ushort bits = value;
            return ExtractBiasedExponentFromBits(bits);
        }
    }

    // Unbiased exponent (BiasedExponent - ExponentBias)
    internal sbyte Exponent
    {
        get
        {
            return (sbyte)(BiasedExponent - ExponentBias);
        }
    }

    // Trailing significand plus the implicit leading bit for non-zero biased exponents
    internal ushort Significand
    {
        get
        {
            return (ushort)(TrailingSignificand | ((BiasedExponent != 0) ? (1U << BiasedExponentShift) : 0U));
        }
    }

    // Stored 10-bit trailing significand field
    internal ushort TrailingSignificand
    {
        get
        {
            ushort bits = value;
            return ExtractTrailingSignificandFromBits(bits);
        }
    }

    internal static byte ExtractBiasedExponentFromBits(ushort bits)
    {
        return (byte)((bits >> BiasedExponentShift) & ShiftedBiasedExponentMask);
    }

    internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
    {
        return (ushort)(bits & TrailingSignificandMask);
    }

    /// <summary>
    /// Compares values of two Float16.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if left is less than right according to IEEE</returns>
    public static bool operator <(Float16 left, Float16 right)
    {
        if (IsNaN(left) || IsNaN(right))
        {
            // IEEE defines that NaN is unordered with respect to everything, including itself.
            return false;
        }

        bool leftIsNegative = IsNegative(left);

        if (leftIsNegative != IsNegative(right))
        {
            // When the signs of left and right differ, we know that left is less than right if it is
            // the negative value. The exception to this is if both values are zero, in which case IEEE
            // says they should be equal, even if the signs differ.
            return leftIsNegative && !AreZero(left, right);
        }

        // Same sign: sign-magnitude bits order like unsigned integers for positives and in reverse for negatives.
        return (left.value != right.value) && ((left.value < right.value) ^ leftIsNegative);
    }

    /// <summary>
    /// Compares values of two Float16.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if left is greater than right according to IEEE</returns>
    public static bool operator >(Float16 left, Float16 right)
    {
        return right < left;
    }

    /// <summary>
    /// Compares values of two Float16.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if left is less than or equal to right according to IEEE</returns>
    public static bool operator <=(Float16 left, Float16 right)
    {
        if (IsNaN(left) || IsNaN(right))
        {
            // IEEE defines that NaN is unordered with respect to everything, including itself.
            return false;
        }

        bool leftIsNegative = IsNegative(left);

        if (leftIsNegative != IsNegative(right))
        {
            // When the signs of left and right differ, we know that left is less than right if it is
            // the negative value. The exception to this is if both values are zero, in which case IEEE
            // says they should be equal, even if the signs differ.
            return leftIsNegative || AreZero(left, right);
        }

        return (left.value == right.value) || ((left.value < right.value) ^ leftIsNegative);
    }

    /// <summary>
    /// Compares values of two Float16.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if left is greater than or equal to right according to IEEE</returns>
    public static bool operator >=(Float16 left, Float16 right)
    {
        return right <= left;
    }

    /// <summary>
    /// Compares values of two Float16 for binary equality.
    /// If either of the values is NaN, this will return false.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if the bit patterns are equal and neither value is NaN</returns>
    /// <remarks>Note that +0.0 and -0.0 compare unequal here because their bits differ.</remarks>
    public static bool operator ==(Float16 left, Float16 right)
    {
        if (IsNaN(left) || IsNaN(right))
        {
            // IEEE defines that NaN is not equal to anything, including itself.
            return false;
        }

        return left.value == right.value;
    }

    /// <summary>
    /// Compares values of two Float16 for binary inequality.
    /// If either of the values is NaN, this will return true.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>the negation of <c>left == right</c></returns>
    public static bool operator !=(Float16 left, Float16 right)
    {
        return !(left == right);
    }

    /// <summary>
    /// Determines whether the specified value is finite (zero, subnormal, or normal).
    /// </summary>
    /// <param name="value">Float16 instance.</param>
    /// <returns>true if the value is finite</returns>
    public static bool IsFinite(Float16 value)
    {
        return StripSign(value) < PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is infinite.
    /// </summary>
    /// <param name="value">Float16 instance.</param>
    /// <returns>true if the value is infinite</returns>
    public static bool IsInfinity(Float16 value)
    {
        return StripSign(value) == PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is NaN.
    /// </summary>
    /// <param name="value">Float16 instance</param>
    /// <returns>true if the value is not a number</returns>
    public static bool IsNaN(Float16 value)
    {
        return StripSign(value) > PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is negative (sign bit set; includes -0.0 and negative NaNs).
    /// </summary>
    /// <param name="value">Float16 instance</param>
    /// <returns>true if the value is negative</returns>
    public static bool IsNegative(Float16 value)
    {
        return (short)(value.value) < 0;
    }

    /// <summary>
    /// Determines whether the specified value is negative infinity.
    /// </summary>
    /// <param name="value">Float16 instance</param>
    /// <returns>true if the value is negative infinity</returns>
    public static bool IsNegativeInfinity(Float16 value)
    {
        return value.value == NegativeInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is normal.
    /// </summary>
    /// <param name="value">Float16 instance</param>
    /// <returns>true if the value is finite, non-zero and has a non-zero exponent</returns>
    public static bool IsNormal(Float16 value)
    {
        uint absValue = StripSign(value);
        return (absValue < PositiveInfinityBits)    // is finite
            && (absValue != 0)                      // is not zero
            && ((absValue & BiasedExponentMask) != 0);    // is not subnormal (has a non-zero exponent)
    }

    /// <summary>
    /// Determines whether the specified value is positive infinity.
    /// </summary>
    /// <param name="value">Float16 instance</param>
    /// <returns>true if the value is positive infinity</returns>
    public static bool IsPositiveInfinity(Float16 value)
    {
        return value.value == PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is subnormal.
    /// </summary>
    /// <param name="value">Float16 instance</param>
    /// <returns>true if the value is subnormal</returns>
    public static bool IsSubnormal(Float16 value)
    {
        uint absValue = StripSign(value);
        return (absValue < PositiveInfinityBits)    // is finite
            && (absValue != 0)                      // is not zero
            && ((absValue & BiasedExponentMask) == 0);    // is subnormal (has a zero exponent)
    }

    /// <summary>
    /// Compares this object to another object, returning an integer that indicates the relationship.
    /// </summary>
    /// <param name="obj">Object to compare to</param>
    /// <returns>A value less than zero if this is less than <paramref name="obj"/>,
    /// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
    /// if this is greater than <paramref name="obj"/> or <paramref name="obj"/> is null.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="obj"/> is not of type <see cref="Float16"/>.</exception>
    public int CompareTo(object obj)
    {
        if (obj is Float16 other)
        {
            return CompareTo(other);
        }

        if (obj is null)
        {
            return 1;
        }

        throw new ArgumentException("Object must be of type Float16");
    }

    /// <summary>
    /// Compares this object to another object, returning an integer that indicates the relationship.
    /// NaN sorts before every other value and equal to another NaN.
    /// </summary>
    /// <param name="other">Object to compare to</param>
    /// <returns>A value less than zero if this is less than <paramref name="other"/>,
    /// zero if this is equal to <paramref name="other"/>,
    /// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
    public int CompareTo(Float16 other)
    {
        if (this < other)
        {
            return -1;
        }

        if (this > other)
        {
            return 1;
        }

        if (this == other)
        {
            return 0;
        }

        // Reaching here means at least one operand is NaN, or the values are +0.0 and -0.0.
        if (IsNaN(this))
        {
            return IsNaN(other) ? 0 : -1;
        }

        Debug.Assert(IsNaN(other));
        return 1;
    }

    /// <summary>
    /// Returns a value indicating whether this instance and other Float16 represent the same value.
    /// Unlike operator ==, all NaNs are equal to each other and +0.0 equals -0.0.
    /// </summary>
    /// <param name="other">A Float16 object to compare to this instance.</param>
    /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
    public bool Equals(Float16 other)
    {
        return value == other.value
            || AreZero(this, other)
            || (IsNaN(this) && IsNaN(other));
    }

    /// <summary>
    /// Returns a value indicating whether this instance and a specified System.Object
    /// represent the same type and value.
    /// </summary>
    /// <param name="obj">An System.Object.</param>
    /// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
    public override bool Equals(object obj)
    {
        return (obj is Float16 other) && Equals(other);
    }

    /// <summary>
    /// Returns the hash code for this instance.
    /// </summary>
    /// <returns>A 32-bit signed integer hash code.</returns>
    public override int GetHashCode()
    {
        if (IsNaNOrZero(this))
        {
            // All NaNs should have the same hash code, as should both Zeros.
            return value & PositiveInfinityBits;
        }

        return value;
    }

    /// <summary>
    /// Returns a string representation of the current value.
    /// </summary>
    /// <returns>Text representation of Float16: the raw bits followed by the float value</returns>
    public override string ToString()
    {
        return $"{value} : {ToFloat()}";
    }

    /// <summary>
    /// Explicit conversion
    /// </summary>
    /// <returns>single precision value converted from Float16</returns>
    public float ToFloat()
    {
        return (float)this;
    }

    /// <summary>Explicitly converts a <see cref="float" /> value to its nearest representable half-precision floating-point value
    /// (round to nearest, ties to even; values too large become infinity).</summary>
    /// <param name="value">The value to convert.</param>
    /// <returns><paramref name="value" /> converted to its nearest representable half-precision floating-point value.</returns>
    public static explicit operator Float16(float value)
    {
        const int SingleMaxExponent = 0xFF;

        uint floatInt = BitOpsUtils.SingleToUInt32Bits(value);
        bool sign = (floatInt & BitOpsUtils.SingleSignMask) >> BitOpsUtils.SingleSignShift != 0;
        int exp = (int)(floatInt & BitOpsUtils.SingleBiasedExponentMask) >> BitOpsUtils.SingleBiasedExponentShift;
        uint sig = floatInt & BitOpsUtils.SingleTrailingSignificandMask;

        if (exp == SingleMaxExponent)
        {
            if (sig != 0) // NaN
            {
                return CreateFloat16NaN(sign, (ulong)sig << 41); // Shift the significand bits to the left end
            }
            return sign ? NegativeInfinity : PositiveInfinity;
        }

        // Keep the top 14 significand bits; any lost low bits are "jammed" into the LSB
        // so rounding still sees them as sticky bits.
        uint sigHalf = sig >> 9 | ((sig & 0x1FFU) != 0 ? 1U : 0U); // RightShiftJam

        if ((exp | (int)sigHalf) == 0)
        {
            return new Float16(sign, 0, 0);
        }

        // 0x71 rebiases the exponent from single (127) to half (15) precision, adjusted for
        // the 0x4000 implicit-bit position expected by RoundPackToFloat16.
        return new Float16(RoundPackToFloat16(sign, (short)(exp - 0x71), (ushort)(sigHalf | 0x4000)));
    }

    /// <summary>Explicitly converts a half-precision floating-point value to its nearest representable <see cref="float" /> value
    /// (the conversion is exact for all non-NaN values).</summary>
    /// <param name="value">The value to convert.</param>
    /// <returns><paramref name="value" /> converted to its nearest representable <see cref="float" /> value.</returns>
    public static explicit operator float(Float16 value)
    {
        bool sign = IsNegative(value);
        int exp = value.BiasedExponent;
        uint sig = value.TrailingSignificand;

        if (exp == MaxBiasedExponent)
        {
            if (sig != 0)
            {
                // Left-align the 10-bit trailing significand in 64 bits (bits 54..63); CreateSingleNaN
                // takes bits 41..63, which places it in single precision bits 13..22.
                return BitOpsUtils.CreateSingleNaN(sign, (ulong)sig << 54);
            }
            return sign ? float.NegativeInfinity : float.PositiveInfinity;
        }

        if (exp == 0)
        {
            if (sig == 0)
            {
                // Positive / Negative zero
                return (sign) ? -0.0f : 0.0f;
            }

            // Subnormal half values are normal in single precision: normalize the significand.
            (exp, sig) = NormSubnormalF16Sig(sig);
            exp -= 1;
        }

        // 0x70 rebiases the exponent from half (15) to single (127) precision, minus one because
        // a normalized significand's leading bit carries into the exponent in CreateSingle.
        return BitOpsUtils.CreateSingle(sign, (byte)(exp + 0x70), sig << 13);
    }

    /// <summary>
    /// Flips the sign. NaNs are not affected.
    /// IEEE 754 specifies NaNs to be propagated.
    /// </summary>
    /// <param name="value">value to negate</param>
    /// <returns>value with its sign flipped, or value itself if it is NaN</returns>
    public static Float16 Negate(Float16 value)
    {
        return IsNaN(value) ? value : new Float16((ushort)(value.value ^ SignMask));
    }

    #region Utilities

    private static bool AreZero(Float16 left, Float16 right)
    {
        // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
        // for two values by or'ing the private bits together and stripping the sign. They are both zero,
        // and therefore equivalent, if the resulting value is still zero.
        return (ushort)((left.value | right.value) & ~SignMask) == 0;
    }

    /// <summary>
    /// The function returns true if the value is either NaN or zero.
    /// </summary>
    /// <param name="value">instance of Float16</param>
    /// <returns>true if NaN or zero.</returns>
    public static bool IsNaNOrZero(Float16 value)
    {
        uint abs = StripSign(value);
        return (abs == 0 || abs > PositiveInfinityBits);
    }

    // Returns the bits with the sign bit cleared (the absolute value's bit pattern).
    private static uint StripSign(Float16 value)
    {
        return (ushort)(value.value & ~SignMask);
    }

    // Normalizes a non-zero subnormal significand so its leading one lands at bit 10;
    // returns the matching exponent adjustment.
    private static (int Exp, uint Sig) NormSubnormalF16Sig(uint sig)
    {
        int shiftDist = BitOpsUtils.LeadingZeroCount(sig) - 16 - 5;
        return (1 - shiftDist, sig << shiftDist);
    }

    // Significand bits should be shifted towards the left end before calling these methods.
    // Always creates a quiet NaN (the quiet bit 0x200 is set even if significand == 0).
    private static Float16 CreateFloat16NaN(bool sign, ulong significand)
    {
        const ushort NaNBits = BiasedExponentMask | 0x200; // Most significant significand bit

        uint signInt = (sign ? 1U : 0U) << SignShift;
        ushort sigInt = (ushort)(significand >> 54);

        ushort ushortBits = (ushort)(signInt | NaNBits | sigInt);
        return new Float16(ushortBits);
    }

    // Rounds a significand with 4 extra low-order rounding bits (implicit bit at 0x4000)
    // to half precision and packs it with sign and exponent. Handles underflow to
    // subnormal/zero and overflow to infinity.
    private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
    {
        const int RoundIncrement = 0x8; // Depends on rounding mode but it's always towards closest / ties to even
        int roundBits = sig & 0xF;

        // The unsigned compare catches both exp < 0 and exp >= 0x1D in one test.
        if ((uint)exp >= 0x1D)
        {
            if (exp < 0)
            {
                sig = (ushort)ShiftRightJam(sig, -exp);
                exp = 0;
                roundBits = sig & 0xF;
            }
            else if (exp > 0x1D || sig + RoundIncrement >= 0x8000) // Overflow
            {
                return sign ? NegativeInfinityBits : PositiveInfinityBits;
            }
        }

        sig = (ushort)((sig + RoundIncrement) >> 4);
        // On an exact tie (roundBits == 8) clear the LSB to round to even.
        sig &= (ushort)~(((roundBits ^ 8) != 0 ? 0 : 1) & 1);

        if (sig == 0)
        {
            exp = 0;
        }

        return new Float16(sign, (ushort)exp, sig).value;
    }

    // If any bits are lost by shifting, "jam" them into the LSB.
    // if dist > bit count, Will be 1 or 0 depending on i
    // (unlike bitwise operators that masks the lower 5 bits)
    private static uint ShiftRightJam(uint i, int dist) => dist < 31 ? (i >> dist) | (i << (-dist & 31) != 0 ? 1U : 0U) : (i != 0 ? 1U : 0U);
    private static ulong ShiftRightJam(ulong l, int dist) => dist < 63 ? (l >> dist) | (l << (-dist & 63) != 0 ? 1UL : 0UL) : (l != 0 ? 1UL : 0UL);

    #endregion
}
/// <summary>
/// This value type represents a BFloat16 value (the upper 16 bits of an IEEE 754 single precision value).
/// See https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
/// for details.
/// It is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
/// and as such is represented the same way in managed and native memory. This means that arrays of this type
/// do not have to be copied to be passed to native memory but can simply be pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to the Onnxruntime library.
/// Binary-wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly struct BFloat16 :
    IComparable,
    IComparable<BFloat16>,
    IEquatable<BFloat16>
{
    internal const ushort SignMask = 0x8000;
    internal const int SignShift = 15;
    internal const byte ShiftedSignMask = SignMask >> SignShift;
    internal const ushort BiasedExponentMask = 0x7F80; // 0b_0111_1111_1000_0000;
    internal const int BiasedExponentShift = 7;
    internal const byte ShiftedBiasedExponentMask = BiasedExponentMask >> BiasedExponentShift;
    internal const ushort TrailingSignificandMask = 0x007F; // 0b_0000_0000_0111_1111;
    internal const byte MinSign = 0;
    internal const byte MaxSign = 1;
    internal const byte MinBiasedExponent = 0x00;
    internal const byte MaxBiasedExponent = 0xFF;
    internal const byte ExponentBias = 127;
    internal const sbyte MinExponent = -126;
    internal const sbyte MaxExponent = +127;

    // Constants representing the private bit-representation for various default values
    private const ushort PositiveZeroBits = 0x0000;
    private const ushort NegativeZeroBits = 0x8000;
    private const ushort OneBits = 0x3F80; // 0b0_01111111_0000000
    private const ushort PositiveInfinityBits = 0x7F80;
    private const ushort NegativeInfinityBits = 0xFF80;
    private const ushort PositiveQNaNBits = 0x7FC1;
    private const ushort NegativeQNaNBits = 0xFFC1;
    private const ushort MinValueBits = 0xFF7F; // 0b1_11111110_1111111, -3.3895314E+38
    private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111, 3.3895314E+38
    private const ushort EpsilonBits = 0x0080; // the smallest positive normal value, 2^-126
    private const ushort PiBits = 0x4049; // 0b0_10000000_1001001, 3.140625

    // Added together with the LSB of the truncated result when converting from float,
    // which implements round to nearest, ties to even.
    private const uint RoundingBase = 0x7FFF;

    // Well-defined and commonly used values

    /// <summary>
    /// BFloat16 Epsilon value: the smallest positive normal value, 2^-126 (approximately 1.1754944E-38).
    /// </summary>
    public static BFloat16 Epsilon => new BFloat16(EpsilonBits);

    /// <summary>
    /// BFloat16 value closest to Pi (3.140625).
    /// </summary>
    public static BFloat16 Pi => new BFloat16(PiBits);

    /// <summary>
    /// BFloat16 Positive infinity value
    /// </summary>
    public static BFloat16 PositiveInfinity => new BFloat16(PositiveInfinityBits);

    /// <summary>
    /// BFloat16 Negative infinity value
    /// </summary>
    public static BFloat16 NegativeInfinity => new BFloat16(NegativeInfinityBits);

    /// <summary>
    /// BFloat16 NaN (bit pattern 0xFFC1, a negative quiet NaN)
    /// </summary>
    public static BFloat16 NaN => new BFloat16(NegativeQNaNBits);

    /// <summary>
    /// BFloat16 Positive Zero (0.0)
    /// </summary>
    public static BFloat16 Zero => new BFloat16(PositiveZeroBits);

    /// <summary>
    /// BFloat16 One (1.0)
    /// </summary>
    public static BFloat16 One => new BFloat16(OneBits);

    /// <summary>
    /// BFloat16 Negative Zero (-0.0)
    /// </summary>
    public static BFloat16 NegativeZero => new BFloat16(NegativeZeroBits);

    /// <summary>
    /// BFloat16 Min value: the most negative finite value, approximately -3.3895314E+38 (bits 0xFF7F).
    /// </summary>
    public static BFloat16 MinValue => new BFloat16(MinValueBits);

    /// <summary>
    /// BFloat16 Max value: the largest finite value, approximately 3.3895314E+38 (bits 0x7F7F).
    /// </summary>
    public static BFloat16 MaxValue => new BFloat16(MaxValueBits);

    /// <summary>
    /// bfloat16 representation bits
    /// </summary>
    public readonly ushort value;

    /// <summary>
    /// Constructor from ushort, no conversion takes place. The value
    /// is assumed to already be a bfloat16 bit pattern.
    /// </summary>
    /// <param name="v">bfloat16 representation bits</param>
    public BFloat16(ushort v)
    {
        value = v;
    }

    // Extracts biased exponent bits
    internal byte BiasedExponent
    {
        get
        {
            ushort bits = value;
            return ExtractBiasedExponentFromBits(bits);
        }
    }

    // Extracts the 7 trailing significand bits
    internal ushort TrailingSignificand
    {
        get
        {
            ushort bits = value;
            return ExtractTrailingSignificandFromBits(bits);
        }
    }

    internal static byte ExtractBiasedExponentFromBits(ushort bits)
    {
        return (byte)((bits >> BiasedExponentShift) & ShiftedBiasedExponentMask);
    }

    internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
    {
        return (ushort)(bits & TrailingSignificandMask);
    }

    /// <summary>
    /// Compares two BFloat16 instances.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if the left is less than right according to IEEE</returns>
    public static bool operator <(BFloat16 left, BFloat16 right)
    {
        if (IsNaN(left) || IsNaN(right))
        {
            // IEEE defines that NaN is unordered with respect to everything, including itself.
            return false;
        }

        bool leftIsNegative = IsNegative(left);

        if (leftIsNegative != IsNegative(right))
        {
            // When the signs of left and right differ, we know that left is less than right if it is
            // the negative value. The exception to this is if both values are zero, in which case IEEE
            // says they should be equal, even if the signs differ.
            return leftIsNegative && !AreZero(left, right);
        }

        // Same sign: sign-magnitude bits order like unsigned integers for positives and in reverse for negatives.
        return (left.value != right.value) && ((left.value < right.value) ^ leftIsNegative);
    }

    /// <summary>
    /// Compares two BFloat16 instances.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if the left is greater than right according to IEEE</returns>
    public static bool operator >(BFloat16 left, BFloat16 right)
    {
        return right < left;
    }

    /// <summary>
    /// Compares two BFloat16 instances.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if the left is less than or equal to right according to IEEE</returns>
    public static bool operator <=(BFloat16 left, BFloat16 right)
    {
        if (IsNaN(left) || IsNaN(right))
        {
            // IEEE defines that NaN is unordered with respect to everything, including itself.
            return false;
        }

        bool leftIsNegative = IsNegative(left);

        if (leftIsNegative != IsNegative(right))
        {
            // When the signs of left and right differ, we know that left is less than right if it is
            // the negative value. The exception to this is if both values are zero, in which case IEEE
            // says they should be equal, even if the signs differ.
            return leftIsNegative || AreZero(left, right);
        }

        return (left.value == right.value) || ((left.value < right.value) ^ leftIsNegative);
    }

    /// <summary>
    /// Compares two BFloat16 instances.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if the left is greater than or equal to right according to IEEE</returns>
    public static bool operator >=(BFloat16 left, BFloat16 right)
    {
        return right <= left;
    }

    /// <summary>
    /// Compares values of two BFloat16 for binary equality.
    /// If either of the values is NaN, this will return false.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>true if the bit patterns are equal and neither value is NaN</returns>
    /// <remarks>Note that +0.0 and -0.0 compare unequal here because their bits differ.</remarks>
    public static bool operator ==(BFloat16 left, BFloat16 right)
    {
        if (IsNaN(left) || IsNaN(right))
        {
            // IEEE defines that NaN is not equal to anything, including itself.
            return false;
        }

        return left.value == right.value;
    }

    /// <summary>
    /// Compares values of two BFloat16 for binary inequality.
    /// If either of the values is NaN it would return true.
    /// </summary>
    /// <param name="left">left hand side</param>
    /// <param name="right">right hand side</param>
    /// <returns>the negation of <c>left == right</c></returns>
    public static bool operator !=(BFloat16 left, BFloat16 right)
    {
        return !(left == right);
    }

    /// <summary>
    /// Determines whether the specified value is finite (zero, subnormal, or normal).
    /// </summary>
    /// <param name="value">BFloat16 instance.</param>
    /// <returns>true if the value is finite</returns>
    public static bool IsFinite(BFloat16 value)
    {
        return StripSign(value) < PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is infinite.
    /// </summary>
    /// <param name="value">BFloat16 instance.</param>
    /// <returns>true if the value is infinite</returns>
    public static bool IsInfinity(BFloat16 value)
    {
        return StripSign(value) == PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is NaN.
    /// </summary>
    /// <param name="value">BFloat16 instance</param>
    /// <returns>true if the value is not a number</returns>
    public static bool IsNaN(BFloat16 value)
    {
        return StripSign(value) > PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is negative (sign bit set; includes -0.0 and negative NaNs).
    /// </summary>
    /// <param name="value">BFloat16 instance</param>
    /// <returns>true if the value is negative</returns>
    public static bool IsNegative(BFloat16 value)
    {
        return (short)(value.value) < 0;
    }

    /// <summary>
    /// Determines whether the specified value is negative infinity.
    /// </summary>
    /// <param name="value">BFloat16 instance</param>
    /// <returns>true if the value is negative infinity</returns>
    public static bool IsNegativeInfinity(BFloat16 value)
    {
        return value.value == NegativeInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is normal.
    /// </summary>
    /// <param name="value">BFloat16 instance</param>
    /// <returns>true if the value is finite, non-zero and has a non-zero exponent</returns>
    public static bool IsNormal(BFloat16 value)
    {
        uint absValue = StripSign(value);
        return (absValue < PositiveInfinityBits)    // is finite
            && (absValue != 0)                      // is not zero
            && ((absValue & BiasedExponentMask) != 0);    // is not subnormal (has a non-zero exponent)
    }

    /// <summary>
    /// Determines whether the specified value is positive infinity.
    /// </summary>
    /// <param name="value">BFloat16 instance</param>
    /// <returns>true if the value is positive infinity</returns>
    public static bool IsPositiveInfinity(BFloat16 value)
    {
        return value.value == PositiveInfinityBits;
    }

    /// <summary>
    /// Determines whether the specified value is subnormal.
    /// </summary>
    /// <param name="value">BFloat16 instance</param>
    /// <returns>true if the value is subnormal</returns>
    public static bool IsSubnormal(BFloat16 value)
    {
        uint absValue = StripSign(value);
        return (absValue < PositiveInfinityBits)    // is finite
            && (absValue != 0)                      // is not zero
            && ((absValue & BiasedExponentMask) == 0);    // is subnormal (has a zero exponent)
    }

    /// <summary>
    /// Compares this object to another object, returning an integer that indicates the relationship.
    /// </summary>
    /// <param name="obj">Object to compare to</param>
    /// <returns>A value less than zero if this is less than <paramref name="obj"/>,
    /// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
    /// if this is greater than <paramref name="obj"/> or <paramref name="obj"/> is null.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="obj"/> is not of type <see cref="BFloat16"/>.</exception>
    public int CompareTo(object obj)
    {
        if (obj is BFloat16 other)
        {
            return CompareTo(other);
        }

        if (obj is null)
        {
            return 1;
        }

        throw new ArgumentException("Object must be of type BFloat16");
    }

    /// <summary>
    /// Compares this object to another object, returning an integer that indicates the relationship.
    /// NaN sorts before every other value and equal to another NaN.
    /// </summary>
    /// <param name="other">Object to compare to</param>
    /// <returns>A value less than zero if this is less than <paramref name="other"/>,
    /// zero if this is equal to <paramref name="other"/>,
    /// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
    public int CompareTo(BFloat16 other)
    {
        if (this < other)
        {
            return -1;
        }

        if (this > other)
        {
            return 1;
        }

        if (this == other)
        {
            return 0;
        }

        // Reaching here means at least one operand is NaN, or the values are +0.0 and -0.0.
        if (IsNaN(this))
        {
            return IsNaN(other) ? 0 : -1;
        }

        Debug.Assert(IsNaN(other));
        return 1;
    }

    /// <summary>
    /// Returns a value indicating whether this instance and other BFloat16 represent the same value.
    /// Unlike operator ==, all NaNs are equal to each other and +0.0 equals -0.0.
    /// </summary>
    /// <param name="other">A BFloat16 object to compare to this instance.</param>
    /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
    public bool Equals(BFloat16 other)
    {
        return value == other.value
            || AreZero(this, other)
            || (IsNaN(this) && IsNaN(other));
    }

    /// <summary>
    /// Returns a value indicating whether this instance and a specified System.Object
    /// represent the same type and value.
    /// </summary>
    /// <param name="obj">An System.Object.</param>
    /// <returns>true if obj is BFloat16 and its value is equal to this instance; otherwise, false.</returns>
    public override bool Equals(object obj)
    {
        return (obj is BFloat16 other) && Equals(other);
    }

    /// <summary>
    /// Returns the hash code for this instance.
    /// </summary>
    /// <returns>A 32-bit signed integer hash code.</returns>
    public override int GetHashCode()
    {
        if (IsNaNOrZero(this))
        {
            // All NaNs should have the same hash code, as should both Zeros.
            return value & PositiveInfinityBits;
        }

        return value;
    }

    /// <summary>
    /// Returns a string representation of the current value.
    /// </summary>
    /// <returns>Text representation of BFloat16: the raw bits followed by the float value</returns>
    public override string ToString()
    {
        return $"{value} : {ToFloat()}";
    }

    /// <summary>
    /// Explicit conversion
    /// </summary>
    /// <returns>single precision value converted from BFloat16</returns>
    public float ToFloat()
    {
        return (float)this;
    }

    /// <summary>Explicitly converts a <see cref="float" /> value to its nearest representable bfloat16 value
    /// (round to nearest, ties to even). Every NaN input produces <see cref="NaN"/>.</summary>
    /// <param name="value">The value to convert.</param>
    /// <returns><paramref name="value" /> converted to its nearest representable bfloat16 value.</returns>
    public static explicit operator BFloat16(float value)
    {
        if (float.IsNaN(value))
        {
            return NaN;
        }

        uint singleBits = BitOpsUtils.SingleToUInt32Bits(value);
        ushort bfloatBits = BitOpsUtils.SingleBitsToBFloat16Bits(singleBits);
        // Round to nearest, ties to even (the same logic pytorch uses): adding 0x7FFF carries into
        // the kept half whenever the dropped half is above 0x8000; at exactly 0x8000 the extra +1
        // (the LSB of the truncated value) carries only when that LSB is odd.
        // Finite values too large to represent round up to infinity; no uint overflow is
        // possible because NaN bit patterns were excluded above.
        singleBits += ((uint)bfloatBits & 1) + RoundingBase;
        bfloatBits = BitOpsUtils.SingleBitsToBFloat16Bits(singleBits);
        return new BFloat16(bfloatBits);
    }

    /// <summary>
    /// Explicitly converts a BFloat16 value to its nearest representable <see cref="float" /> value
    /// (exact for all values; NaNs become quiet NaNs with the same payload).
    /// </summary>
    /// <param name="value">The value to convert.</param>
    /// <returns><paramref name="value" /> converted to its nearest representable <see cref="float" /> value.</returns>
    public static explicit operator float(BFloat16 value)
    {
        bool sign = IsNegative(value);
        int exp = value.BiasedExponent;
        uint sig = value.TrailingSignificand;

        if (exp == MaxBiasedExponent)
        {
            if (sig != 0)
            {
                // Left-align the 7-bit trailing significand in 64 bits (bits 57..63). CreateSingleNaN
                // takes bits 41..63, which places the payload in single precision bits 16..22,
                // exactly where it sits in the float whose upper half is this bfloat16.
                return BitOpsUtils.CreateSingleNaN(sign, (ulong)sig << 57);
            }
            return sign ? float.NegativeInfinity : float.PositiveInfinity;
        }

        if (exp == 0 && sig == 0)
        {
            // Positive / Negative zero
            return (sign) ? -0.0f : 0.0f;
        }

        // All subnormal numbers in BFloat16 would be also subnormal in FP32 because they
        // share the exponent.
        uint singleBits = BitOpsUtils.BFloat16BitsToSingleBits(value.value);
        return BitOpsUtils.UInt32BitsToSingle(singleBits);
    }

    /// <summary>
    /// Flips the sign. NaNs are not affected.
    /// IEEE 754 specifies NaNs to be propagated.
    /// </summary>
    /// <param name="value">value to negate</param>
    /// <returns>value with its sign flipped, or value itself if it is NaN</returns>
    public static BFloat16 Negate(BFloat16 value)
    {
        return IsNaN(value) ? value : new BFloat16((ushort)(value.value ^ SignMask));
    }

    /// <summary>
    /// The function returns true if the value is either NaN or zero.
    /// </summary>
    /// <param name="value">instance of BFloat16</param>
    /// <returns>true if NaN or zero.</returns>
    public static bool IsNaNOrZero(BFloat16 value)
    {
        uint abs = StripSign(value);
        return (abs == 0 || abs > PositiveInfinityBits);
    }

    #region Utilities

    private static bool AreZero(BFloat16 left, BFloat16 right)
    {
        // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
        // for two values by or'ing the private bits together and stripping the sign. They are both zero,
        // and therefore equivalent, if the resulting value is still zero.
        return (ushort)((left.value | right.value) & ~SignMask) == 0;
    }

    // Returns the bits with the sign bit cleared (the absolute value's bit pattern).
    private static uint StripSign(BFloat16 value)
    {
        return (ushort)(value.value & ~SignMask);
    }

    #endregion
}
}