mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
### Description Introduce `Float16/BFloat16` support for C# and C++ APIs. Users should be able to perform conversions from `float` to/from `Float16/BFloat16`, compare values, and test for `NaN`, `Infinity`, and whether the number is denormalized. ### Motivation and Context Users filed issues such as: https://github.com/microsoft/onnxruntime/issues/14303
1372 lines
No EOL
50 KiB
C#
1372 lines
No EOL
50 KiB
C#
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
// Portions of this code are from System.Half struct dotnet runtime.
|
|
// Licensed to the .NET Foundation under one or more agreements.
|
|
// The .NET Foundation licenses this file to you under the MIT license.
|
|
|
|
using System;
|
|
using System.Diagnostics;
|
|
using System.Runtime.InteropServices;
|
|
|
|
namespace Microsoft.ML.OnnxRuntime
|
|
{
|
|
// Helper routines covering functionality that BitConverter / BitOperations do not
// offer on NETSTANDARD 2.0, together with the IEEE 754 single precision bit-layout
// constants needed by the 16-bit floating point types in this file.
internal class BitOpsUtils
{
    // Single precision format constants, lifted from the .NET runtime sources:
    // https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Single.cs,dda909df0f8d2fd0
    internal const uint SingleBiasedExponentMask = 0x7F80_0000;
    internal const int SingleBiasedExponentShift = 23;

    internal const uint SingleSignMask = 0x8000_0000;
    internal const int SingleSignShift = 31;

    // Most significant bit of the trailing significand
    internal const uint SingleMostSignificantSigBit = 0x400000;
    internal const uint SingleTrailingSignificandMask = 0x007F_FFFF;

    /// <summary>
    /// Counts the zero bits above the highest set bit of <paramref name="num"/>.
    /// Required because BitOperations is not available in NETSTANDARD 2.0.
    /// There are more efficient ways with bit twiddling, but this one has clarity.
    /// </summary>
    /// <param name="num">value to inspect</param>
    /// <returns>number of leading zeros (32 for zero). Useful to compute log2 as well.</returns>
    internal static int LeadingZeroCount(uint num)
    {
        if (num == 0)
        {
            return 32;
        }

        int count = 0;

        // Skip whole zero nibbles first ...
        while ((num & 0xF000_0000) == 0)
        {
            count += 4;
            num <<= 4;
        }

        // ... then finish one bit at a time.
        while ((num & 0x8000_0000) == 0)
        {
            count += 1;
            num <<= 1;
        }

        return count;
    }

    /// <summary>
    /// Reinterprets the bits of a single precision number as a uint
    /// so they can be manipulated. No numeric conversion takes place.
    ///
    /// This API is the reverse of UInt32BitsToSingle().
    /// </summary>
    /// <param name="single">float value</param>
    /// <returns>the bit pattern of <paramref name="single"/></returns>
    internal static uint SingleToUInt32Bits(float single)
    {
        unsafe
        {
            // float and uint are both 4 bytes wide; read the same storage as uint.
            return *(uint*)&single;
        }
    }

    /// <summary>
    /// Reinterprets a uint bit pattern as a single precision number.
    /// Needed because the BitConverter implementation is not available until
    /// later versions. This API is the reverse of SingleToUInt32Bits().
    ///
    /// For the exact bit representation of float see IEEE 754 standard for single precision.
    /// </summary>
    /// <param name="singleBits">bit representation of float either obtained from
    /// SingleToUInt32Bits or assembled using bitwise operators</param>
    /// <returns>the float that has exactly these bits</returns>
    internal static float UInt32BitsToSingle(uint singleBits)
    {
        unsafe
        {
            return *(float*)&singleBits;
        }
    }

    /// <summary>
    /// Truncates single precision bits (obtained from SingleToUInt32Bits() or constructed
    /// manually according to IEEE 754) to bfloat16 bits by keeping the upper 16 bits:
    /// sign, exponent and the top 7 significand bits. No rounding is performed here.
    /// </summary>
    /// <param name="singleBits">bits representation of a single precision number (float)</param>
    /// <returns>the upper 16 bits of <paramref name="singleBits"/></returns>
    internal static ushort SingleBitsToBFloat16Bits(uint singleBits)
    {
        // Shifts operate on the numeric value of the uint, not on its byte layout in
        // memory, so the upper half is selected the same way on every platform.
        // (A former big-endian branch returned the LOW 16 bits, i.e. only significand bits.)
        return (ushort)(singleBits >> 16);
    }

    /// <summary>
    /// Widens bfloat16 bits to single precision bits, which in turn can be
    /// manipulated or converted to float using UInt32BitsToSingle().
    /// </summary>
    /// <param name="bfloatBits">ushort bits representation of bfloat16</param>
    /// <returns>single precision bits whose upper 16 bits are <paramref name="bfloatBits"/>
    /// and whose lower 16 bits are zero</returns>
    internal static uint BFloat16BitsToSingleBits(ushort bfloatBits)
    {
        // Endianness-independent for the same reason as SingleBitsToBFloat16Bits().
        return (uint)bfloatBits << 16;
    }

    /// <summary>
    /// Creates a float NaN with the given sign and payload.
    /// </summary>
    /// <param name="sign">true for negative</param>
    /// <param name="significand">payload, left-aligned in the 64-bit value (most significant
    /// payload bit at bit 63). Only its upper 23 bits end up in the float significand.</param>
    /// <returns>a NaN; the most significant significand bit is always set</returns>
    internal static float CreateSingleNaN(bool sign, ulong significand)
    {
        // We need to set at least one bit in the NaN significand,
        // otherwise the pattern would denote infinity.
        const uint NaNBits = SingleBiasedExponentMask | SingleMostSignificantSigBit;

        uint signInt = (sign ? 1U : 0U) << SingleSignShift;
        // 64 - 41 = 23 bits remain, exactly the width of the trailing significand.
        uint sigInt = (uint)(significand >> 41);
        uint singleBits = signInt | NaNBits | sigInt;

        return UInt32BitsToSingle(singleBits);
    }

    /// <summary>
    /// Creates float from sign, exponent and significand
    /// </summary>
    /// <param name="sign">true if negative</param>
    /// <param name="exponent">biased exponent</param>
    /// <param name="significand">significand. It is ADDED to the shifted exponent, so a bit
    /// set at position 23 or above carries into the exponent field.</param>
    /// <returns>the assembled float</returns>
    internal static float CreateSingle(bool sign, byte exponent, uint significand)
    {
        uint signInt = (sign ? 1U : 0U) << SingleSignShift;
        // Addition (not OR) is deliberate: callers may rely on the carry into the exponent.
        uint expInt = ((uint)exponent << SingleBiasedExponentShift) + significand;
        uint singleBits = signInt + expInt;

        return UInt32BitsToSingle(singleBits);
    }
}
|
|
|
|
|
|
/// <summary>
|
|
/// This value type represents A Float16 value
|
|
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
|
|
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
|
|
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
|
|
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
|
|
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
|
|
///
|
|
/// The implementation is derived from
|
|
/// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974
|
|
/// </summary>
|
|
[StructLayout(LayoutKind.Sequential)]
|
|
public readonly struct Float16 :
|
|
IComparable,
|
|
IComparable<Float16>,
|
|
IEquatable<Float16>
|
|
{
|
|
// Bit layout: 1 sign bit | 5 biased exponent bits | 10 trailing significand bits

// Sign field
internal const int SignShift = 15;
internal const ushort SignMask = 0x8000;                  // 0b1_00000_0000000000
internal const byte ShiftedSignMask = SignMask >> SignShift;

// Biased exponent field
internal const int BiasedExponentShift = 10;
internal const ushort BiasedExponentMask = 0x7C00;        // 0b0_11111_0000000000
internal const byte ShiftedBiasedExponentMask = BiasedExponentMask >> BiasedExponentShift;

// Trailing significand field
internal const ushort TrailingSignificandMask = 0x03FF;   // 0b0_00000_1111111111

// Ranges of the individual fields
internal const byte MinSign = 0;
internal const byte MaxSign = 1;

internal const byte MinBiasedExponent = 0x00;
internal const byte MaxBiasedExponent = 0x1F;

internal const byte ExponentBias = 15;

internal const sbyte MinExponent = -14;
internal const sbyte MaxExponent = +15;

// Raw bit patterns of frequently used values

private const ushort PositiveZeroBits = 0x0000;
private const ushort NegativeZeroBits = 0x8000;

private const ushort OneBits = 0x3C00;

// Exponent field 1, significand 0: 2^-14 (about 6.1035E-05), the smallest positive NORMAL value.
// NOTE(review): System.Half.Epsilon is 0x0001, the smallest positive subnormal (5.9604645E-08);
// the pattern here instead follows the "smallest positive normal" choice that
// BFloat16.EpsilonBits documents. Confirm this is intended.
private const ushort EpsilonBits = 0x0400;

private const ushort PositiveInfinityBits = 0x7C00;
private const ushort NegativeInfinityBits = 0xFC00;

// Quiet NaNs: exponent all ones plus the most significant significand bit
private const ushort PositiveQNaNBits = 0x7E00;
private const ushort NegativeQNaNBits = 0xFE00;

private const ushort MinValueBits = 0xFBFF;               // -65504
private const ushort MaxValueBits = 0x7BFF;               // +65504

private const ushort PositiveOneBits = 0x3C00;
private const ushort NegativeOneBits = 0xBC00;

private const ushort EBits = 0x4170;                      // 2.71875
private const ushort PiBits = 0x4248;                     // 3.140625
private const ushort TauBits = 0x4648;                    // 6.28125
|
|
|
|
// Frequently used values, each created from its raw bit pattern

/// <summary>
/// Float16 Epsilon value. With the bit pattern used (0x0400) this is the
/// smallest positive normal value, 2^-14 (about 6.1035E-05).
/// </summary>
public static Float16 Epsilon
{
    get { return new Float16(EpsilonBits); }
}

/// <summary>
/// Float16 Pi value (3.140625, the closest half precision value to Pi)
/// </summary>
public static Float16 Pi
{
    get { return new Float16(PiBits); }
}

/// <summary>
/// Float16 Positive Infinity value (1.0 / 0.0)
/// </summary>
public static Float16 PositiveInfinity
{
    get { return new Float16(PositiveInfinityBits); }
}

/// <summary>
/// Float16 Negative Infinity value (-1.0 / 0.0)
/// </summary>
public static Float16 NegativeInfinity
{
    get { return new Float16(NegativeInfinityBits); }
}

/// <summary>
/// Float16 NaN (0.0 / 0.0). The negative quiet NaN bit pattern is used.
/// </summary>
public static Float16 NaN
{
    get { return new Float16(NegativeQNaNBits); }
}

/// <summary>
/// Float16 Zero value (+0.0)
/// </summary>
public static Float16 Zero
{
    get { return new Float16(PositiveZeroBits); }
}

/// <summary>
/// Float16 One value (1.0)
/// </summary>
public static Float16 One
{
    get { return new Float16(OneBits); }
}

/// <summary>
/// Float16 Negative Zero value (-0.0)
/// </summary>
public static Float16 NegativeZero
{
    get { return new Float16(NegativeZeroBits); }
}

/// <summary>
/// Float16 Min value: -65504, the finite value of largest magnitude with the sign bit set
/// (raw bits 0xFBFF, i.e. 64,511 when read as an integer)
/// </summary>
public static Float16 MinValue
{
    get { return new Float16(MinValueBits); }
}

/// <summary>
/// Float16 Max value: 65504, the largest finite value
/// (raw bits 0x7BFF, i.e. 31,743 when read as an integer)
/// </summary>
public static Float16 MaxValue
{
    get { return new Float16(MaxValueBits); }
}
|
|
|
|
/// <summary>
/// Raw float16 bits (sign, biased exponent, trailing significand)
/// </summary>
public readonly ushort value;

/// <summary>
/// Wraps already encoded float16 bits; no numeric conversion is performed.
/// </summary>
/// <param name="v">float16 representation bits</param>
public Float16(ushort v) => value = v;

// Assembles the bits from the individual fields. The parts are ADDED rather than
// OR-ed, so a significand wider than its field carries into the exponent.
private Float16(bool sign, ushort exp, ushort sig)
{
    int signPart = (sign ? 1 : 0) << SignShift;
    int exponentPart = exp << BiasedExponentShift;
    value = (ushort)(signPart + exponentPart + sig);
}
|
|
|
|
// Biased exponent field of this instance
internal byte BiasedExponent => ExtractBiasedExponentFromBits(value);

// Exponent with the bias removed
internal sbyte Exponent => (sbyte)(BiasedExponent - ExponentBias);

// Trailing significand with the implicit leading bit added whenever
// the biased exponent is non-zero
internal ushort Significand
{
    get
    {
        uint implicitBit = 0U;
        if (BiasedExponent != 0)
        {
            implicitBit = 1U << BiasedExponentShift;
        }
        return (ushort)(TrailingSignificand | implicitBit);
    }
}

// Trailing significand field of this instance
internal ushort TrailingSignificand => ExtractTrailingSignificandFromBits(value);

// Isolates the biased exponent field of a raw bit pattern
internal static byte ExtractBiasedExponentFromBits(ushort bits)
{
    int shifted = bits >> BiasedExponentShift;
    return (byte)(shifted & ShiftedBiasedExponentMask);
}

// Isolates the trailing significand field of a raw bit pattern
internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
{
    int field = bits & TrailingSignificandMask;
    return (ushort)field;
}
|
|
|
|
/// <summary>
/// Less-than comparison of two Float16 values.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is less than right according to IEEE;
/// always false when either operand is NaN</returns>
public static bool operator <(Float16 left, Float16 right)
{
    // IEEE defines that NaN is unordered with respect to everything, including itself.
    if (IsNaN(left) || IsNaN(right))
    {
        return false;
    }

    bool negativeLeft = IsNegative(left);
    bool negativeRight = IsNegative(right);

    if (negativeLeft == negativeRight)
    {
        if (left.value == right.value)
        {
            return false;
        }

        // Same sign: the raw bits order like the magnitudes, so the
        // ordering is reversed for negative values.
        bool rawLess = left.value < right.value;
        return negativeLeft ? !rawLess : rawLess;
    }

    // Different signs: the negative operand is the smaller one, except that
    // IEEE treats -0 and +0 as equal.
    return negativeLeft && !AreZero(left, right);
}
|
|
|
|
/// <summary>
/// Greater-than comparison of two Float16 values,
/// evaluated as the less-than comparison with the operands swapped.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is greater than right according to IEEE</returns>
public static bool operator >(Float16 left, Float16 right) => right < left;
|
|
|
|
/// <summary>
/// Less-than-or-equal comparison of two Float16 values.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is less than or equal to right according to IEEE;
/// always false when either operand is NaN</returns>
public static bool operator <=(Float16 left, Float16 right)
{
    // IEEE defines that NaN is unordered with respect to everything, including itself.
    if (IsNaN(left) || IsNaN(right))
    {
        return false;
    }

    bool negativeLeft = IsNegative(left);
    bool negativeRight = IsNegative(right);

    if (negativeLeft == negativeRight)
    {
        if (left.value == right.value)
        {
            return true;
        }

        // Same sign: the raw bits order like the magnitudes, so the
        // ordering is reversed for negative values.
        bool rawLess = left.value < right.value;
        return negativeLeft ? !rawLess : rawLess;
    }

    // Different signs: true when left is the negative one, or when both are
    // zeros (IEEE treats +0 and -0 as equal).
    return negativeLeft || AreZero(left, right);
}
|
|
|
|
/// <summary>
/// Greater-than-or-equal comparison of two Float16 values,
/// evaluated as the less-than-or-equal comparison with the operands swapped.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is greater than or equal to right according to IEEE</returns>
public static bool operator >=(Float16 left, Float16 right) => right <= left;
|
|
|
|
/// <summary>
/// Compares values of two Float16 for equality according to IEEE rules:
/// if either of the values is NaN the result is false, and
/// positive and negative zero are considered equal.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if values are equal according to IEEE</returns>
public static bool operator ==(Float16 left, Float16 right)
{
    if (IsNaN(left) || IsNaN(right))
    {
        // IEEE defines that NaN is not equal to anything, including itself.
        return false;
    }

    // IEEE defines that positive and negative zero are equivalent. Without the
    // AreZero check, -0 == +0 would be false although both -0 <= +0 and -0 >= +0
    // are true and Equals(-0, +0) is true.
    return (left.value == right.value) || AreZero(left, right);
}
|
|
|
|
/// <summary>
/// Compares values of two Float16 for inequality; the exact negation of operator ==.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if values are not equal according to IEEE</returns>
public static bool operator !=(Float16 left, Float16 right)
{
    bool areEqual = left == right;
    return !areEqual;
}
|
|
|
|
/// <summary>
/// Tells whether the value is finite (zero, subnormal, or normal),
/// i.e. its magnitude bits are below those of infinity.
/// </summary>
/// <param name="value">Float16 instance.</param>
/// <returns>true if the value is finite</returns>
public static bool IsFinite(Float16 value)
{
    uint magnitude = StripSign(value);
    return magnitude < PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Tells whether the value is positive or negative infinity.
/// </summary>
/// <param name="value">Float16 instance.</param>
/// <returns>true if the value is infinite</returns>
public static bool IsInfinity(Float16 value)
{
    uint magnitude = StripSign(value);
    return magnitude == PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Tells whether the value is NaN, i.e. its magnitude bits are
/// above those of infinity.
/// </summary>
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(Float16 value)
{
    uint magnitude = StripSign(value);
    return magnitude > PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Tells whether the sign bit of the value is set. This is a pure bit test,
/// so it also reports true for negative zero and for NaNs with the sign bit set.
/// </summary>
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is negative</returns>
public static bool IsNegative(Float16 value)
{
    int signBit = value.value & SignMask;
    return signBit != 0;
}
|
|
|
|
/// <summary>
/// Tells whether the value is exactly negative infinity.
/// </summary>
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(Float16 value) => value.value == NegativeInfinityBits;
|
|
|
|
|
|
/// <summary>
/// Tells whether the value is normal: finite, non-zero and not subnormal.
/// </summary>
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is normal</returns>
public static bool IsNormal(Float16 value)
{
    uint magnitude = StripSign(value);

    if (magnitude >= PositiveInfinityBits)
    {
        // infinity or NaN
        return false;
    }

    // A non-zero exponent field rules out both zero and subnormals.
    return (magnitude & BiasedExponentMask) != 0;
}
|
|
|
|
|
|
/// <summary>
/// Tells whether the value is exactly positive infinity.
/// </summary>
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is positive infinity</returns>
public static bool IsPositiveInfinity(Float16 value) => value.value == PositiveInfinityBits;
|
|
|
|
/// <summary>
/// Tells whether the value is subnormal: non-zero with a zero exponent field.
/// </summary>
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is subnormal</returns>
public static bool IsSubnormal(Float16 value)
{
    uint magnitude = StripSign(value);

    if (magnitude == 0)
    {
        // zero is not subnormal
        return false;
    }

    // A zero exponent field also implies that the value is finite.
    return (magnitude & BiasedExponentMask) == 0;
}
|
|
|
|
/// <summary>
/// Compares this instance with an arbitrary object and reports their relative order.
/// </summary>
/// <param name="obj">Object to compare to; null or a boxed Float16</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
/// if this is greater than <paramref name="obj"/>. A null <paramref name="obj"/>
/// yields 1.
/// </returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="obj"/> is neither null nor of type <see cref="Float16"/>.</exception>
public int CompareTo(object obj)
{
    if (obj is Float16 other)
    {
        return CompareTo(other);
    }

    if (obj is null)
    {
        return 1;
    }

    throw new ArgumentException("Object must be of type Float16");
}
|
|
|
|
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// Unlike the comparison operators this defines a total order: NaN compares equal to NaN
/// and sorts before every other value; positive and negative zero compare equal.
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(Float16 other)
{
    if (this < other)
    {
        return -1;
    }

    if (this > other)
    {
        return 1;
    }

    // Equals() is true for identical bits, for +0 versus -0 and for NaN versus NaN.
    // Signed zeros must be handled here: they are neither less nor greater than each
    // other, and would otherwise fall through to the NaN handling below.
    if (Equals(other))
    {
        return 0;
    }

    // Not ordered and not equal: exactly one of the operands is NaN.
    if (IsNaN(this))
    {
        return -1;
    }

    Debug.Assert(IsNaN(other));
    return 1;
}
|
|
|
|
/// <summary>
/// Tells whether this instance and another Float16 represent the same value.
/// In contrast to the IEEE comparison, two NaNs are considered equal here;
/// positive and negative zero are considered equal as well.
/// </summary>
/// <param name="other">A Float16 object to compare to this instance.</param>
/// <returns>true if other is equal to this instance; otherwise, false.</returns>
public bool Equals(Float16 other)
{
    if (value == other.value)
    {
        return true;
    }

    if (AreZero(this, other))
    {
        return true;
    }

    return IsNaN(this) && IsNaN(other);
}
|
|
|
|
/// <summary>
/// Tells whether a System.Object is a Float16 that equals this instance.
/// </summary>
/// <param name="obj">A System.Object.</param>
/// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
public override bool Equals(object obj)
{
    if (obj is Float16 other)
    {
        return Equals(other);
    }

    return false;
}
|
|
|
|
/// <summary>
/// Computes the hash code of this instance, consistent with Equals().
/// </summary>
/// <returns>A 32-bit signed integer hash code.</returns>
public override int GetHashCode()
{
    if (!IsNaNOrZero(this))
    {
        return value;
    }

    // All NaNs should have the same hash code, as should both zeros:
    // keeping only the exponent bits maps every NaN to one code and both zeros to 0.
    return value & PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Formats the value as "raw bits : float value", where the raw bits are printed
/// as an integer and the float value is the result of ToFloat().
/// </summary>
/// <returns>Text representation of Float16</returns>
public override string ToString()
{
    return string.Format("{0} : {1}", value, ToFloat());
}
|
|
|
|
/// <summary>
/// Converts this instance to single precision through the explicit conversion operator.
/// </summary>
/// <returns>single precision value converted from Float16</returns>
public float ToFloat() => (float)this;
|
|
|
|
/// <summary>Explicitly converts a <see cref="float" /> value to its nearest representable half-precision floating-point value.</summary>
/// <param name="value">The value to convert.</param>
/// <returns><paramref name="value" /> converted to its nearest representable half-precision floating-point value.</returns>
public static explicit operator Float16(float value)
{
    // Biased exponent shared by single precision infinities and NaNs
    const int AllOnesExponent = 0xFF;

    uint bits = BitOpsUtils.SingleToUInt32Bits(value);
    bool isNegative = (bits & BitOpsUtils.SingleSignMask) != 0;
    int biasedExp = (int)((bits & BitOpsUtils.SingleBiasedExponentMask) >> BitOpsUtils.SingleBiasedExponentShift);
    uint fraction = bits & BitOpsUtils.SingleTrailingSignificandMask;

    if (biasedExp == AllOnesExponent)
    {
        if (fraction == 0)
        {
            return isNegative ? NegativeInfinity : PositiveInfinity;
        }

        // NaN: pass the payload left-aligned in 64 bits
        return CreateFloat16NaN(isNegative, (ulong)fraction << 41);
    }

    // Drop the 9 lowest significand bits; if any of them was set, record that
    // in the lowest remaining bit ("jam") so rounding still sees it.
    uint sticky = (fraction & 0x1FFU) != 0 ? 1U : 0U;
    uint narrowed = (fraction >> 9) | sticky;

    if (biasedExp == 0 && narrowed == 0)
    {
        // signed zero
        return new Float16(isNegative, 0, 0);
    }

    // Re-bias the exponent, add the leading bit (0x4000) and let the helper round and pack.
    short rebasedExp = (short)(biasedExp - 0x71);
    ushort extendedSig = (ushort)(narrowed | 0x4000);
    return new Float16(RoundPackToFloat16(isNegative, rebasedExp, extendedSig));
}
|
|
|
|
/// <summary>Explicitly converts a half-precision floating-point value to its nearest representable <see cref="float" /> value.</summary>
/// <param name="value">The value to convert.</param>
/// <returns><paramref name="value" /> converted to its nearest representable <see cref="float" /> value.</returns>
public static explicit operator float(Float16 value)
{
    bool isNegative = IsNegative(value);
    int biasedExp = value.BiasedExponent;
    uint fraction = value.TrailingSignificand;

    if (biasedExp == MaxBiasedExponent)
    {
        if (fraction == 0)
        {
            return isNegative ? float.NegativeInfinity : float.PositiveInfinity;
        }

        // NaN: left-align the 10 significand bits in 64 bits (10 + 54 = 64)
        return BitOpsUtils.CreateSingleNaN(isNegative, (ulong)fraction << 54);
    }

    if (biasedExp == 0)
    {
        if (fraction == 0)
        {
            // Positive / Negative zero
            return isNegative ? -0.0f : 0.0f;
        }

        // Subnormal: normalize the significand. The normalized significand still
        // contains its leading bit, which CreateSingle() adds into the exponent
        // field; the exponent is lowered by one to compensate.
        var normalized = NormSubnormalF16Sig(fraction);
        biasedExp = normalized.Exp - 1;
        fraction = normalized.Sig;
    }

    // 0x70 re-biases the exponent; the 10-bit significand is widened to 23 bits.
    byte singleExp = (byte)(biasedExp + 0x70);
    return BitOpsUtils.CreateSingle(isNegative, singleExp, fraction << 13);
}
|
|
|
|
/// <summary>
/// Returns the value with its sign bit flipped. NaNs are returned unchanged,
/// since IEEE 754 specifies NaNs to be propagated.
/// </summary>
/// <param name="value">Float16 instance to negate</param>
/// <returns>the negated value, or <paramref name="value"/> itself if it is NaN</returns>
public static Float16 Negate(Float16 value)
{
    if (IsNaN(value))
    {
        return value;
    }

    ushort flipped = (ushort)(value.value ^ SignMask);
    return new Float16(flipped);
}
|
|
|
|
#region Utilities
|
|
|
|
// True when BOTH values are zeros of either sign.
private static bool AreZero(Float16 left, Float16 right)
{
    // IEEE defines that positive and negative zero are equal. OR-ing the raw bits
    // and ignoring the sign bit gives a single test: the result is zero only if
    // neither value has any exponent or significand bit set.
    int combined = left.value | right.value;
    return (combined & ~SignMask) == 0;
}
|
|
|
|
/// <summary>
/// Tells whether the value is a zero (of either sign) or a NaN.
/// </summary>
/// <param name="value">instance of Float16</param>
/// <returns>true if NaN or zero.</returns>
public static bool IsNaNOrZero(Float16 value)
{
    uint magnitude = StripSign(value);

    if (magnitude == 0)
    {
        return true;
    }

    return magnitude > PositiveInfinityBits;
}
|
|
|
|
// Returns the raw bits with the sign bit cleared (exponent and significand only).
private static uint StripSign(Float16 value)
{
    int magnitude = value.value & ~SignMask;
    return (uint)magnitude;
}
|
|
|
|
// Normalizes the significand of a subnormal: shifts it left until its highest set
// bit sits at bit 10 (21 leading zeros in 32 bits) and returns the matching
// exponent together with the shifted significand. Callers pass a non-zero value.
private static (int Exp, uint Sig) NormSubnormalF16Sig(uint sig)
{
    int leadingZeros = BitOpsUtils.LeadingZeroCount(sig);
    int shift = leadingZeros - 16 - 5;
    uint shiftedSig = sig << shift;
    return (1 - shift, shiftedSig);
}
|
|
|
|
|
|
// Builds a NaN with the given sign. The significand must be passed left-aligned in
// the 64-bit argument; its upper 10 bits are kept as payload. The most significant
// significand bit is always set, so the result is a NaN even for significand == 0.
private static Float16 CreateFloat16NaN(bool sign, ulong significand)
{
    // Exponent all ones plus the most significant significand bit
    const ushort NaNBits = BiasedExponentMask | 0x200;

    uint signBits = sign ? (1U << SignShift) : 0U;
    ushort payload = (ushort)(significand >> 54);

    return new Float16((ushort)(signBits | NaNBits | payload));
}
|
|
|
|
// Rounds a significand that carries 4 extra low-order bits (round to nearest,
// ties to even) and packs sign, exponent and significand into float16 bits.
// Exponents below zero are folded into the significand by a jamming right shift;
// results that are too large produce the infinity of the given sign.
private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
{
    // Half of the weight of the lowest kept bit. Depends on rounding mode,
    // but it's always towards closest / ties to even.
    const int HalfUlp = 0x8;
    int discarded = sig & 0xF;

    if (exp < 0)
    {
        // Too small for a normal number: shift the significand instead
        sig = (ushort)ShiftRightJam(sig, -exp);
        exp = 0;
        discarded = sig & 0xF;
    }
    else if (exp >= 0x1D)
    {
        if (exp > 0x1D || sig + HalfUlp >= 0x8000) // Overflow
        {
            return sign ? NegativeInfinityBits : PositiveInfinityBits;
        }
    }

    sig = (ushort)((sig + HalfUlp) >> 4);

    if (discarded == HalfUlp)
    {
        // Exact tie: clear the lowest bit so the result is even
        sig &= 0xFFFE;
    }

    if (sig == 0)
    {
        exp = 0;
    }

    return new Float16(sign, (ushort)exp, sig).value;
}
|
|
|
|
// Right shift that "jams" lost bits into the LSB: if any bit is shifted out,
// the lowest bit of the result is set.
// For a distance of 31 or more the result is just 1 or 0 depending on whether i is non-zero
// (unlike the plain shift operators, which only use the lower 5 bits of the distance).
private static uint ShiftRightJam(uint i, int dist)
{
    if (dist >= 31)
    {
        return i != 0 ? 1U : 0U;
    }

    // The bits that fall off on the right are exactly those that survive this left shift.
    uint sticky = (i << (-dist & 31)) != 0 ? 1U : 0U;
    return (i >> dist) | sticky;
}

// 64-bit variant; here the distance limit is 63 and shifts use the lower 6 bits.
private static ulong ShiftRightJam(ulong l, int dist)
{
    if (dist >= 63)
    {
        return l != 0 ? 1UL : 0UL;
    }

    ulong sticky = (l << (-dist & 63)) != 0 ? 1UL : 0UL;
    return (l >> dist) | sticky;
}
|
|
|
|
#endregion
|
|
}
|
|
|
|
/// <summary>
|
|
/// This value type represents A BFloat16 value.
|
|
/// See https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
|
|
/// for details.
|
|
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
|
|
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
|
|
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
|
|
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
|
|
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
|
|
/// </summary>
|
|
[StructLayout(LayoutKind.Sequential)]
|
|
public readonly struct BFloat16 :
|
|
IComparable,
|
|
IComparable<BFloat16>,
|
|
IEquatable<BFloat16>
|
|
{
|
|
// Bit layout: 1 sign bit | 8 biased exponent bits | 7 trailing significand bits

// Sign field
internal const int SignShift = 15;
internal const ushort SignMask = 0x8000;                  // 0b1_00000000_0000000
internal const byte ShiftedSignMask = SignMask >> SignShift;

// Biased exponent field
internal const int BiasedExponentShift = 7;
internal const ushort BiasedExponentMask = 0x7F80;        // 0b0_11111111_0000000
internal const byte ShiftedBiasedExponentMask = BiasedExponentMask >> BiasedExponentShift;

// Trailing significand field
internal const ushort TrailingSignificandMask = 0x007F;   // 0b0_00000000_1111111

// Ranges of the individual fields
internal const byte MinSign = 0;
internal const byte MaxSign = 1;

internal const byte MinBiasedExponent = 0x00;
internal const byte MaxBiasedExponent = 0xFF;

internal const byte ExponentBias = 127;

internal const sbyte MinExponent = -126;
internal const sbyte MaxExponent = +127;

// Raw bit patterns of frequently used values

private const ushort PositiveZeroBits = 0x0000;
private const ushort NegativeZeroBits = 0x8000;

private const ushort OneBits = 0x3F80;                    // 0b0_01111111_0000000

private const ushort PositiveInfinityBits = 0x7F80;
private const ushort NegativeInfinityBits = 0xFF80;

// Quiet NaNs: exponent all ones, most significant significand bit and lowest bit set
private const ushort PositiveQNaNBits = 0x7FC1;
private const ushort NegativeQNaNBits = 0xFFC1;

private const ushort MinValueBits = 0xFF7F;               // 0b1_11111110_1111111
private const ushort MaxValueBits = 0x7F7F;               // 0b0_11111110_1111111

private const ushort EpsilonBits = 0x0080;                // the smallest positive normal value

private const ushort PiBits = 0x4049;                     // 0b0_10000000_1001001 (3.140625)

// Used for rounding subnormal values
// NOTE(review): the code using this constant is outside the visible part of the file.
private const uint RoundingBase = 0x7FFF;
|
|
|
|
// Frequently used values, each created from its raw bit pattern

/// <summary>
/// BFloat16 Epsilon value (the smallest positive normal value)
/// </summary>
public static BFloat16 Epsilon
{
    get { return new BFloat16(EpsilonBits); }
}

/// <summary>
/// BFloat16 Pi value (3.140625, the closest bfloat16 value to Pi)
/// </summary>
public static BFloat16 Pi
{
    get { return new BFloat16(PiBits); }
}

/// <summary>
/// BFloat16 Positive infinity value
/// </summary>
public static BFloat16 PositiveInfinity
{
    get { return new BFloat16(PositiveInfinityBits); }
}

/// <summary>
/// BFloat16 Negative infinity value
/// </summary>
public static BFloat16 NegativeInfinity
{
    get { return new BFloat16(NegativeInfinityBits); }
}

/// <summary>
/// BFloat16 NaN. The negative quiet NaN bit pattern is used.
/// </summary>
public static BFloat16 NaN
{
    get { return new BFloat16(NegativeQNaNBits); }
}

/// <summary>
/// BFloat16 Positive Zero (+0.0)
/// </summary>
public static BFloat16 Zero
{
    get { return new BFloat16(PositiveZeroBits); }
}

/// <summary>
/// BFloat16 One (1.0)
/// </summary>
public static BFloat16 One
{
    get { return new BFloat16(OneBits); }
}

/// <summary>
/// BFloat16 Negative Zero (-0.0)
/// </summary>
public static BFloat16 NegativeZero
{
    get { return new BFloat16(NegativeZeroBits); }
}

/// <summary>
/// BFloat16 Min value: the finite value of largest magnitude with the sign bit set
/// (raw bits 0xFF7F, i.e. 65,407 when read as an integer)
/// </summary>
public static BFloat16 MinValue
{
    get { return new BFloat16(MinValueBits); }
}

/// <summary>
/// BFloat16 Max value: the largest finite value
/// (raw bits 0x7F7F, i.e. 32,639 when read as an integer)
/// </summary>
public static BFloat16 MaxValue
{
    get { return new BFloat16(MaxValueBits); }
}
|
|
|
|
/// <summary>
/// Raw bfloat16 bits (sign, biased exponent, trailing significand)
/// </summary>
public readonly ushort value;

/// <summary>
/// Wraps already encoded bfloat16 bits. No conversion takes place; the
/// argument is assumed to be a valid bfloat16 bit pattern.
/// </summary>
/// <param name="v">bfloat16 representation bits</param>
public BFloat16(ushort v) => value = v;
|
|
|
|
// Extracts biased exponent bits
|
|
internal byte BiasedExponent
|
|
{
|
|
get
|
|
{
|
|
ushort bits = value;
|
|
return ExtractBiasedExponentFromBits(bits);
|
|
}
|
|
}
|
|
|
|
// Trailing significand field of this instance, obtained by passing the raw bits
// to ExtractTrailingSignificandFromBits.
internal ushort TrailingSignificand => ExtractTrailingSignificandFromBits(value);
|
|
|
|
// Isolates the biased exponent: shift the raw bits right by BiasedExponentShift,
// then keep only the bits selected by ShiftedBiasedExponentMask.
internal static byte ExtractBiasedExponentFromBits(ushort bits)
{
    var shifted = bits >> BiasedExponentShift;
    return (byte)(shifted & ShiftedBiasedExponentMask);
}
|
|
|
|
// Isolates the trailing significand by masking the raw bits with TrailingSignificandMask.
internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
{
    var masked = bits & TrailingSignificandMask;
    return (ushort)masked;
}
|
|
|
|
/// <summary>
/// Less-than comparison of two BFloat16 instances.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is less than right according to IEEE; false if either operand is NaN</returns>
public static bool operator <(BFloat16 left, BFloat16 right)
{
    // NaN is unordered with respect to everything, including itself.
    if (IsNaN(left) || IsNaN(right))
    {
        return false;
    }

    bool leftNeg = IsNegative(left);
    bool rightNeg = IsNegative(right);

    if (leftNeg == rightNeg)
    {
        // Same sign: identical bits are never "less". Otherwise compare the raw
        // encodings, inverting the answer when both operands are negative.
        if (left.value == right.value)
        {
            return false;
        }

        return (left.value < right.value) ^ leftNeg;
    }

    // Signs differ: only a negative left can be smaller, and -0 is not smaller than +0.
    if (!leftNeg)
    {
        return false;
    }

    return !AreZero(left, right);
}
|
|
|
|
/// <summary>
/// Greater-than comparison of two BFloat16 instances, implemented as <c>right &lt; left</c>.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is greater than right according to IEEE</returns>
public static bool operator >(BFloat16 left, BFloat16 right) => right < left;
|
|
|
|
/// <summary>
/// Less-than-or-equal comparison of two BFloat16 instances.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is less than or equal to right according to IEEE; false if either operand is NaN</returns>
public static bool operator <=(BFloat16 left, BFloat16 right)
{
    // NaN is unordered with respect to everything, including itself.
    if (IsNaN(left) || IsNaN(right))
    {
        return false;
    }

    bool leftNeg = IsNegative(left);
    bool rightNeg = IsNegative(right);

    if (leftNeg == rightNeg)
    {
        // Same sign: identical bits are equal. Otherwise compare the raw
        // encodings, inverting the answer when both operands are negative.
        if (left.value == right.value)
        {
            return true;
        }

        return (left.value < right.value) ^ leftNeg;
    }

    // Signs differ: a negative left is always <= a non-negative right.
    // A non-negative left qualifies only when both operands are zero (+0 equals -0).
    if (leftNeg)
    {
        return true;
    }

    return AreZero(left, right);
}
|
|
|
|
/// <summary>
/// Greater-than-or-equal comparison of two BFloat16 instances, implemented as <c>right &lt;= left</c>.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>true if left is greater than or equal to right according to IEEE</returns>
public static bool operator >=(BFloat16 left, BFloat16 right) => right <= left;
|
|
|
|
/// <summary>
/// Compares the values of two BFloat16 instances for equality following IEEE rules.
/// If either of the values is NaN, this returns false. Positive and negative zero
/// compare equal, which keeps this operator consistent with the ordering operators
/// and with <see cref="Equals(BFloat16)"/>.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>result of value comparisons</returns>
public static bool operator ==(BFloat16 left, BFloat16 right)
{
    if (IsNaN(left) || IsNaN(right))
    {
        // IEEE defines that NaN is not equal to anything, including itself.
        return false;
    }

    // IEEE defines that positive and negative zero are equivalent, so a pure
    // bit comparison is not sufficient.
    return left.value == right.value || AreZero(left, right);
}
|
|
|
|
/// <summary>
/// Inequality comparison of two BFloat16 instances; the logical negation of the equality operator.
/// If either of the values is NaN this returns true.
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
/// <returns>result of value comparisons</returns>
public static bool operator !=(BFloat16 left, BFloat16 right)
{
    bool areEqual = left == right;
    return !areEqual;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is finite (zero, subnormal, or normal).
/// </summary>
/// <param name="value">BFloat16 instance.</param>
/// <returns>true if the value is finite</returns>
public static bool IsFinite(BFloat16 value)
{
    // With the sign removed, every finite encoding is below the infinity encoding.
    uint magnitude = StripSign(value);
    return magnitude < PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is infinite (positive or negative).
/// </summary>
/// <param name="value">BFloat16 instance.</param>
/// <returns>true if the value is infinite</returns>
public static bool IsInfinity(BFloat16 value)
{
    // The sign is stripped first, so both infinities match the positive encoding.
    uint magnitude = StripSign(value);
    return magnitude == PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(BFloat16 value)
{
    // With the sign removed, any encoding above the infinity encoding is a NaN.
    uint magnitude = StripSign(value);
    return magnitude > PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is negative, i.e. whether its sign bit is set.
/// This is true for negative zero, negative infinity and NaNs with the sign bit set as well.
/// </summary>
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is negative</returns>
public static bool IsNegative(BFloat16 value)
{
    // Test the sign bit directly. The previous (short) cast of the raw bits relied on
    // an unchecked conversion and would throw OverflowException for every negative
    // value when compiled with overflow checking enabled.
    return (value.value & SignMask) != 0;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the raw bits equal the negative infinity encoding</returns>
public static bool IsNegativeInfinity(BFloat16 value)
{
    ushort rawBits = value.value;
    return rawBits == NegativeInfinityBits;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is normal: finite, non-zero and
/// with a non-zero exponent field.
/// </summary>
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is normal</returns>
public static bool IsNormal(BFloat16 value)
{
    uint magnitude = StripSign(value);

    // Zero, infinity and NaN are never normal.
    if (magnitude == 0 || magnitude >= PositiveInfinityBits)
    {
        return false;
    }

    // A non-zero exponent field separates normal from subnormal values.
    return (magnitude & BiasedExponentMask) != 0;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is positive infinity.
/// </summary>
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the raw bits equal the positive infinity encoding</returns>
public static bool IsPositiveInfinity(BFloat16 value)
{
    ushort rawBits = value.value;
    return rawBits == PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Determines whether the specified value is subnormal: finite, non-zero and
/// with a zero exponent field.
/// </summary>
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is subnormal</returns>
public static bool IsSubnormal(BFloat16 value)
{
    uint magnitude = StripSign(value);

    // Zero, infinity and NaN are never subnormal.
    if (magnitude == 0 || magnitude >= PositiveInfinityBits)
    {
        return false;
    }

    // A zero exponent field on a non-zero finite value means subnormal.
    return (magnitude & BiasedExponentMask) == 0;
}
|
|
|
|
/// <summary>
/// Compares this instance to an arbitrary object and reports their relative order.
/// </summary>
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
/// if this is greater than <paramref name="obj"/>. A null <paramref name="obj"/> yields 1.
/// </returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="obj"/> is neither null nor of type <see cref="BFloat16"/>.</exception>
public int CompareTo(object obj)
{
    if (obj is BFloat16 other)
    {
        return CompareTo(other);
    }

    // Any instance sorts after null.
    if (obj is null)
    {
        return 1;
    }

    throw new ArgumentException("Object must be of type BFloat16");
}
|
|
|
|
/// <summary>
/// Compares this instance to another BFloat16 and reports their relative order.
/// NaN sorts before every non-NaN value and two NaNs compare as equal.
/// Positive and negative zero compare as equal, consistent with <see cref="Equals(BFloat16)"/>.
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(BFloat16 other)
{
    if (this < other)
    {
        return -1;
    }

    if (this > other)
    {
        return 1;
    }

    // Neither operand is strictly smaller. Equals covers identical bits, +0 versus -0
    // and NaN versus NaN. The zero case used to fall through to the NaN handling
    // below, tripping the debug assertion and returning 1.
    if (Equals(other))
    {
        return 0;
    }

    // Only the "exactly one operand is NaN" case is left.
    Debug.Assert(IsNaN(this) ^ IsNaN(other));
    return IsNaN(this) ? -1 : 1;
}
|
|
|
|
/// <summary>
/// Returns a value indicating whether this instance and another BFloat16 represent the same value.
/// Identical bits, a pair of zeros of any sign, and a pair of NaNs are all considered equal.
/// </summary>
/// <param name="other">A BFloat16 object to compare to this instance.</param>
/// <returns>true if other is equal to this instance; otherwise, false.</returns>
public bool Equals(BFloat16 other)
{
    if (value == other.value)
    {
        return true;
    }

    if (AreZero(this, other))
    {
        return true;
    }

    return IsNaN(this) && IsNaN(other);
}
|
|
|
|
/// <summary>
/// Returns a value indicating whether this instance and a specified System.Object
/// represent the same type and value.
/// </summary>
/// <param name="obj">A System.Object.</param>
/// <returns>true if obj is a BFloat16 and its value is equal to this instance; otherwise, false.</returns>
public override bool Equals(object obj)
{
    if (obj is BFloat16 other)
    {
        return Equals(other);
    }

    return false;
}
|
|
|
|
/// <summary>
/// Returns the hash code for this instance.
/// </summary>
/// <returns>A 32-bit signed integer hash code.</returns>
public override int GetHashCode()
{
    if (!IsNaNOrZero(this))
    {
        return value;
    }

    // All NaNs should share one hash code, as should both zeros; masking with the
    // infinity encoding discards the bits in which they differ.
    return value & PositiveInfinityBits;
}
|
|
|
|
/// <summary>
/// Returns a string representation of the current value: the raw bits,
/// followed by " : " and the value converted to float.
/// </summary>
/// <returns>Text representation of BFloat16</returns>
public override string ToString()
{
    float asSingle = ToFloat();
    return $"{value} : {asSingle}";
}
|
|
|
|
/// <summary>
/// Converts this instance to single precision via the explicit float conversion operator.
/// </summary>
/// <returns>single precision value converted from BFloat16</returns>
public float ToFloat() => (float)this;
|
|
|
|
/// <summary>Explicitly converts a <see cref="float" /> value to its nearest representable bfloat16 value.</summary>
/// <param name="value">The value to convert.</param>
/// <returns><paramref name="value" /> converted to its nearest representable bfloat16 value. Any NaN input yields <see cref="NaN"/>.</returns>
public static explicit operator BFloat16(float value)
{
    if (float.IsNaN(value))
    {
        return NaN;
    }

    uint bits = BitOpsUtils.SingleToUInt32Bits(value);

    // Rounding follows the logic pytorch uses. RoundingBase (0x7FFF per the original
    // comment) plus the lowest bit of the unrounded bfloat16 result is added to the
    // single precision bits, so a carry reaches the kept bits when either that lowest
    // bit is 1 or the discarded single bits are large enough.
    uint lowestKeptBit = (uint)BitOpsUtils.SingleBitsToBFloat16Bits(bits) & 1;
    bits += lowestKeptBit + RoundingBase;

    return new BFloat16(BitOpsUtils.SingleBitsToBFloat16Bits(bits));
}
|
|
|
|
/// <summary>
/// Explicitly converts a BFloat16 value to its nearest representable <see cref="float" /> value.
/// </summary>
/// <param name="value">The value to convert.</param>
/// <returns><paramref name="value" /> converted to its nearest representable <see cref="float" /> value.</returns>
public static explicit operator float(BFloat16 value)
{
    bool negative = IsNegative(value);
    int exponent = value.BiasedExponent;
    uint significand = value.TrailingSignificand;

    if (exponent == MaxBiasedExponent)
    {
        // All-ones exponent: infinity when the significand is empty, NaN otherwise.
        if (significand == 0)
        {
            return negative ? float.NegativeInfinity : float.PositiveInfinity;
        }

        // The significand is moved towards the top of a 64-bit payload for CreateSingleNaN.
        // NOTE(review): the previous comment spoke of a 54-bit shift while the code shifts
        // by 56; the intended amount should be confirmed against CreateSingleNaN.
        return BitOpsUtils.CreateSingleNaN(negative, (ulong)significand << 56);
    }

    if (exponent == 0 && significand == 0)
    {
        // Positive / negative zero.
        if (negative)
        {
            return -0.0f;
        }

        return 0.0f;
    }

    // Per the original author's note, subnormal BFloat16 values are also subnormal in
    // FP32 because the two formats share the exponent, so no special handling is needed.
    uint widened = BitOpsUtils.BFloat16BitsToSingleBits(value.value);
    return BitOpsUtils.UInt32BitsToSingle(widened);
}
|
|
|
|
/// <summary>
/// Flips the sign of the value. NaNs are returned unchanged, since
/// IEEE 754 specifies NaNs to be propagated.
/// </summary>
/// <param name="value">BFloat16 instance to negate</param>
/// <returns>the value with its sign bit toggled, or the same NaN</returns>
public static BFloat16 Negate(BFloat16 value)
{
    if (IsNaN(value))
    {
        return value;
    }

    ushort flipped = (ushort)(value.value ^ SignMask);
    return new BFloat16(flipped);
}
|
|
|
|
/// <summary>
/// Returns true if the value is either NaN or zero (of either sign).
/// </summary>
/// <param name="value">instance of BFloat16</param>
/// <returns>true if NaN or zero.</returns>
public static bool IsNaNOrZero(BFloat16 value)
{
    uint magnitude = StripSign(value);

    if (magnitude == 0)
    {
        return true;
    }

    return magnitude > PositiveInfinityBits;
}
|
|
|
|
#region Utilities
|
|
|
|
// Returns true when both operands are zero, regardless of sign.
// IEEE defines that positive and negative zero are equal. OR-ing the two bit
// patterns and clearing the sign bit leaves zero only when both operands are zeros.
private static bool AreZero(BFloat16 left, BFloat16 right)
{
    var combined = left.value | right.value;
    return (ushort)(combined & ~SignMask) == 0;
}
|
|
|
|
// Returns the raw bits with the sign bit cleared, widened to uint.
private static uint StripSign(BFloat16 value)
{
    ushort magnitude = (ushort)(value.value & ~SignMask);
    return magnitude;
}
|
|
|
|
#endregion
|
|
}
|
|
} |