Add numeric_limits for MLFloat16 and BFloat16 (#22197)

### Description
* Add std::numeric_limits for MLFloat16 and BFloat16.
* Update some comments in csharp ORTFloat16.shared.cs.
* Add unit tests (including Clip)

Note that the canonical NaN is not consistent between C++ and C#. C# uses
a negative quiet NaN as its canonical NaN, while C++ uses a positive quiet NaN.
The C# Float16.NaN value was chosen to be consistent with
System.Half.NaN.

FP16 data returned from CUDA might have 0x7FFF as NaN; FP16 data from the CPU
provider might have 0x7E00 as NaN. In any case, there is no consistent
canonical NaN in ORT right now. Because all of these NaNs conform to the
IEEE spec, this should not cause an issue downstream.

### Motivation and Context
std::numeric_limits is used in codebase but not defined for MLFloat16
and BFloat16. It causes some bugs like
https://github.com/microsoft/onnxruntime/issues/21957 introduced by
https://github.com/microsoft/onnxruntime/pull/21493.
This commit is contained in:
Tianlei Wu 2024-09-25 17:10:05 -07:00 committed by GitHub
parent 72b0979e8a
commit 7880342e5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 268 additions and 57 deletions

View file

@ -60,9 +60,9 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Extracts single precision number bit representation as uint
/// so its bits can be manipulated.
///
///
/// This API is the reverse of UInt32BitsToSingle().
///
///
/// </summary>
/// <param name="single">float value</param>
/// <returns></returns>
@ -79,11 +79,11 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Needed because BitConverter impl is not available until
/// later versions. This API is the reverse of SingleToUInt32Bits().
///
///
/// For the exact bit representation of float see IEEE 754 standard for single precision.
///
///
/// </summary>
/// <param name="singleBits">bit representation of float either obtained from
/// <param name="singleBits">bit representation of float either obtained from
/// SingleToUInt32Bits or assembled using bitwise operators</param>
/// <returns></returns>
internal static float UInt32BitsToSingle(uint singleBits)
@ -99,7 +99,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Converts single precision bits representation which can be obtained using
/// SingleToUInt32Bits() or manually constructed according to IEEE 754 standard.
///
///
/// </summary>
/// <param name="singleBits">bits representation of a single precision number (float)</param>
/// <returns></returns>
@ -177,8 +177,8 @@ namespace Microsoft.ML.OnnxRuntime
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
///
/// The implementation is derived from
///
/// The implementation is derived from
/// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974
/// </summary>
[StructLayout(LayoutKind.Sequential)]
@ -215,6 +215,7 @@ namespace Microsoft.ML.OnnxRuntime
private const ushort OneBits = 0x3C00;
// Minimum positive normalized value. It is corresponding to numeric_limits<float16>::min() in C++.
private const ushort EpsilonBits = 0x0400;
private const ushort PositiveInfinityBits = 0x7C00;
@ -238,7 +239,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Float16 Epsilon value
/// </summary>
public static Float16 Epsilon => new Float16(EpsilonBits); // 5.9604645E-08
public static Float16 Epsilon => new Float16(EpsilonBits); // 0.00006103515625
/// <summary>
/// Float16 Pi value
@ -248,17 +249,17 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Float16 Positive Infinity value
/// </summary>
public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); // 1.0 / 0.0;
public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits);
/// <summary>
/// Float16 Negative Infinity value
/// </summary>
public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); // -1.0 / 0.0
public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits);
/// <summary>
/// Float16 NaN
/// </summary>
public static Float16 NaN => new Float16(NegativeQNaNBits); // 0.0 / 0.0
public static Float16 NaN => new Float16(NegativeQNaNBits); // Same as System.Half.NaN
/// <summary>
/// Float16 Zero value
@ -276,14 +277,14 @@ namespace Microsoft.ML.OnnxRuntime
public static Float16 NegativeZero => new Float16(NegativeZeroBits); // -0.0
/// <summary>
/// Float16 Min value
/// Float16 Lowest value
/// </summary>
public static Float16 MinValue => new Float16(MinValueBits); // 64,511
public static Float16 MinValue => new Float16(MinValueBits); // -65504.0
/// <summary>
/// Float16 Max value
/// </summary>
public static Float16 MaxValue => new Float16(MaxValueBits); // 31,743
public static Float16 MaxValue => new Float16(MaxValueBits); // 65504.0
/// <summary>
/// float16 representation bits
@ -348,7 +349,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
@ -376,7 +377,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
@ -388,7 +389,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
@ -429,7 +430,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares values of two Float16 for binary equality.
/// If either of the values is NaN, this will return false.
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
@ -479,7 +480,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
///
///
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(Float16 value)
@ -500,7 +501,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
///
///
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(Float16 value)
@ -549,7 +550,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// </summary>
///
///
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
@ -570,7 +571,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(Float16 other)
{
@ -864,10 +865,13 @@ namespace Microsoft.ML.OnnxRuntime
private const ushort PositiveQNaNBits = 0x7FC1;
private const ushort NegativeQNaNBits = 0xFFC1;
// Lowest finite value. It is corresponding to numeric_limits<BFloat16>::lowest() in C++.
private const ushort MinValueBits = 0xFF7F; // 0b1_11111110_1111111
private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111
private const ushort EpsilonBits = 0x0080; // the smallest positive normal value
// Minimum positive normalized value. It is corresponding to numeric_limits<BFloat16>::min() in C++.
private const ushort EpsilonBits = 0x0080;
private const ushort PiBits = 0x4049; // 0b0_10000000_1001001
@ -899,7 +903,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// BFloat16 NaN
/// </summary>
public static BFloat16 NaN => new BFloat16(NegativeQNaNBits);
public static BFloat16 NaN => new BFloat16(NegativeQNaNBits); // .Net has no BFloat16. Follow Float16 style.
/// <summary>
/// BFloat16 Positive Zero
@ -919,13 +923,13 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// BFloat16 Min value
/// </summary>
public static BFloat16 MinValue => new BFloat16(MinValueBits); // 65,407
public static BFloat16 MinValue => new BFloat16(MinValueBits); // -3.38953139e38
/// <summary>
/// BFloat16 Max value
/// </summary>
public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 32,639
public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 3.38953139e38
/// <summary>
/// bfloat16 representation bits
@ -1051,7 +1055,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares values of two BFloat16 for binary equality.
/// If either of the values is NaN, this will return false.
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
@ -1102,7 +1106,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
///
///
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(BFloat16 value)
@ -1123,7 +1127,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
///
///
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(BFloat16 value)
@ -1170,7 +1174,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// </summary>
///
///
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
@ -1191,7 +1195,7 @@ namespace Microsoft.ML.OnnxRuntime
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(BFloat16 other)
{
@ -1368,4 +1372,4 @@ namespace Microsoft.ML.OnnxRuntime
#endregion
}
}
}

View file

@ -295,3 +295,147 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
}
} // namespace onnxruntime
namespace std {
template <>
class numeric_limits<onnxruntime::MLFloat16> {
public:
static constexpr onnxruntime::MLFloat16 min() noexcept {
return onnxruntime::MLFloat16::FromBits(0x0400U); // Minimum positive normalized value: 0.00006103515625
}
static constexpr onnxruntime::MLFloat16 max() noexcept {
return onnxruntime::MLFloat16::FromBits(0x7BFFU); // Largest representable value: 65504
}
static constexpr onnxruntime::MLFloat16 lowest() noexcept {
return onnxruntime::MLFloat16::FromBits(0xFBFFU); // Smallest representable value: -65504
}
static constexpr onnxruntime::MLFloat16 infinity() noexcept {
return onnxruntime::MLFloat16::FromBits(0x7C00U); // Bits: sign(0), exponent(111,11), fraction(00,0000,0000)
}
static constexpr onnxruntime::MLFloat16 quiet_NaN() noexcept {
// The most significant fraction bit shall be 1, and no limitation on other fraction bits.
// Note that most frameworks use 0x7E00; while CUDA uses 0x7FFF; .Net System.Half.NaN uses 0xFE00;
return onnxruntime::MLFloat16::FromBits(0x7E00U); // Bits: sign(0), exponent(111,11), fraction(10,0000,0000)
}
static constexpr onnxruntime::MLFloat16 signaling_NaN() noexcept {
return onnxruntime::MLFloat16::FromBits(0x7D00U); // Bits: sign(0), exponent(111,11), fraction(01,0000,0000)
}
static constexpr onnxruntime::MLFloat16 denorm_min() noexcept {
return onnxruntime::MLFloat16::FromBits(0x0001U); // Minimum subnormal value: 0.000000059604645
}
static constexpr onnxruntime::MLFloat16 epsilon() noexcept {
return onnxruntime::MLFloat16::FromBits(0x1400U); // Difference between 1.0 and the next value: 2^-10 = 0.0009765625
}
static constexpr onnxruntime::MLFloat16 round_error() noexcept {
return onnxruntime::MLFloat16::FromBits(0x3800U); // 0.5
}
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
static constexpr bool is_bounded = true;
static constexpr bool is_iec559 = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 11; // Number of significant digits (mantissa)
static constexpr int digits10 = 3; // Decimal digits of precision
static constexpr int max_digits10 = 5; // Max decimal digits required for precision
static constexpr int radix = 2;
static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr std::float_round_style round_style = std::round_to_nearest;
};
template <>
class numeric_limits<onnxruntime::BFloat16> {
public:
static constexpr onnxruntime::BFloat16 min() noexcept {
return onnxruntime::BFloat16::FromBits(0x0080U); // Minimum positive normalized value: 1.175494e-38
}
static constexpr onnxruntime::BFloat16 max() noexcept {
return onnxruntime::BFloat16::FromBits(0x7F7FU); // Largest representable value: 3.38953139e38
}
static constexpr onnxruntime::BFloat16 lowest() noexcept {
return onnxruntime::BFloat16::FromBits(0xFF7FU); // Smallest representable value: -3.38953139e38
}
static constexpr onnxruntime::BFloat16 infinity() noexcept {
return onnxruntime::BFloat16::FromBits(0x7F80U); // Bits: sign(0), exponent(111,1111,1), fraction(000,0000)
}
static constexpr onnxruntime::BFloat16 quiet_NaN() noexcept {
// The most significant fraction bit shall be 1, and no limitation on other fraction bits.
// Note that Torch, Tensorflow, OpenVino, nGraph uses 0x7FC0; Paddle uses 0x7FC1; CUDA uses 0x7FFF.
return onnxruntime::BFloat16::FromBits(0x7FC1U); // Bits: sign(0), exponent(111,1111,1), fraction(100,0001)
}
static constexpr onnxruntime::BFloat16 signaling_NaN() noexcept {
// The most significant fraction bit shall be 0, and there is at least one 1 in other fraction bits.
return onnxruntime::BFloat16::FromBits(0x7F81U); // Bits: sign(0), exponent(111,1111,1), fraction(000,0001)
}
static constexpr onnxruntime::BFloat16 denorm_min() noexcept {
return onnxruntime::BFloat16::FromBits(0x0001U); // Minimum subnormal value: 9.1835e-41
}
static constexpr onnxruntime::BFloat16 epsilon() noexcept {
return onnxruntime::BFloat16::FromBits(0x3C00U); // Difference between 1.0 and the next value: 2^-7 = 0.0078125
}
static constexpr onnxruntime::BFloat16 round_error() noexcept {
return onnxruntime::BFloat16::FromBits(0x3F00U); // 0.5
}
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
static constexpr bool is_bounded = true;
static constexpr bool is_iec559 = false;
static constexpr bool is_modulo = false;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
static constexpr int radix = 2;
static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
} // namespace std

View file

@ -59,33 +59,11 @@ Status Clip_6<T>::ComputeInternal(OpKernelContext* ctx) const {
return Status::OK();
}
namespace clip_internal {
// Supplies the bounds Clip falls back to when no explicit `min` / `max` tensor is given.
// The generic version forwards directly to std::numeric_limits<T>.
template <typename T>
struct LowMax {
  // Lower fallback bound: std::numeric_limits<T>::lowest().
  // The parentheses around the function name keep a `min`/`max`-style macro from expanding.
  static constexpr T low() { return (std::numeric_limits<T>::lowest)(); }

  // Upper fallback bound: std::numeric_limits<T>::max().
  static constexpr T max() { return (std::numeric_limits<T>::max)(); }
};
// MLFloat16 flavour of LowMax: each bound is obtained by converting the matching
// float limit to half precision with math::floatToHalf.
// NOTE(review): the float limits lie outside the finite half range, so the resulting
// bit patterns depend on how math::floatToHalf treats out-of-range input - confirm.
template <>
struct LowMax<MLFloat16> {
  // Lower fallback bound derived from std::numeric_limits<float>::lowest().
  static MLFloat16 low() {
    const float lowest_fp32 = std::numeric_limits<float>::lowest();
    return MLFloat16::FromBits(math::floatToHalf(lowest_fp32));
  }

  // Upper fallback bound derived from std::numeric_limits<float>::max().
  static MLFloat16 max() {
    const float max_fp32 = std::numeric_limits<float>::max();
    return MLFloat16::FromBits(math::floatToHalf(max_fp32));
  }
};
} // namespace clip_internal
template <typename T>
struct Clip::ComputeImpl {
void operator()(cudaStream_t stream, const Tensor* X, const Tensor* min, const Tensor* max, Tensor* Y) const {
auto min_default = clip_internal::LowMax<T>::low();
auto max_default = clip_internal::LowMax<T>::max();
auto min_default = std::numeric_limits<T>::lowest();
auto max_default = std::numeric_limits<T>::max();
const T* min_data = nullptr;
const T* max_data = nullptr;

View file

@ -494,6 +494,25 @@ TEST_F(DataTypeTest, MLFloat16Comparision) {
}
TEST_F(DataTypeTest, MLFloat16TestNAN) {
const MLFloat16 quiet_NaN = std::numeric_limits<MLFloat16>::quiet_NaN();
EXPECT_TRUE(quiet_NaN.IsNaN());
EXPECT_TRUE(quiet_NaN.IsNaNOrZero());
EXPECT_NE(MLFloat16::NaN, quiet_NaN); // NaN are not equal to each other
EXPECT_TRUE(std::isnan(quiet_NaN.ToFloat()));
const MLFloat16 signaling_NaN = std::numeric_limits<MLFloat16>::signaling_NaN();
EXPECT_TRUE(signaling_NaN.IsNaN());
EXPECT_TRUE(signaling_NaN.IsNaNOrZero());
EXPECT_NE(MLFloat16::NaN, signaling_NaN); // NaN are not equal to each other
EXPECT_TRUE(std::isnan(signaling_NaN.ToFloat()));
// NaN used in C# has negative sign
const MLFloat16 csharp_NaN = MLFloat16::FromBits(0xFE00U);
EXPECT_TRUE(csharp_NaN.IsNaN());
EXPECT_TRUE(csharp_NaN.IsNaNOrZero());
EXPECT_NE(BFloat16::NaN, csharp_NaN);
EXPECT_TRUE(std::isnan(csharp_NaN.ToFloat()));
const MLFloat16 fp16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
EXPECT_TRUE(fp16NANFromSingle.IsNaN());
EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero());
@ -520,6 +539,11 @@ TEST_F(DataTypeTest, MLFloat16NaNComparision) {
}
TEST_F(DataTypeTest, MLFloat16Infinity) {
const MLFloat16 fp16_infinity(std::numeric_limits<MLFloat16>::infinity());
EXPECT_TRUE(fp16_infinity.IsInfinity());
EXPECT_FALSE(fp16_infinity.IsFinite());
EXPECT_FALSE(fp16_infinity.IsNegative());
EXPECT_FALSE(MLFloat16::MaxValue.Negate().IsInfinity());
EXPECT_FALSE(MLFloat16::MaxValue.IsInfinity());
EXPECT_TRUE(MLFloat16::MaxValue.IsFinite());
@ -550,6 +574,8 @@ TEST_F(DataTypeTest, MLFloat16NormalSubnormal) {
EXPECT_TRUE(smallest_subnormal.IsSubnormal());
EXPECT_FALSE(smallest_subnormal.IsNormal());
EXPECT_EQ(smallest_subnormal, std::numeric_limits<MLFloat16>::denorm_min());
// float smallest positive subnormal is ~1.40129846432481707092E-45, and
// in float the same number above would be normal
const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
@ -639,6 +665,24 @@ TEST_F(DataTypeTest, BFloat16Comparision) {
}
TEST_F(DataTypeTest, BFloat16TestNAN) {
const BFloat16 quiet_NaN = std::numeric_limits<BFloat16>::quiet_NaN();
EXPECT_TRUE(quiet_NaN.IsNaN());
EXPECT_TRUE(quiet_NaN.IsNaNOrZero());
EXPECT_NE(BFloat16::NaN, quiet_NaN);
EXPECT_TRUE(std::isnan(quiet_NaN.ToFloat()));
const BFloat16 signaling_NaN = std::numeric_limits<BFloat16>::signaling_NaN();
EXPECT_TRUE(signaling_NaN.IsNaN());
EXPECT_TRUE(signaling_NaN.IsNaNOrZero());
EXPECT_NE(BFloat16::NaN, signaling_NaN);
EXPECT_TRUE(std::isnan(signaling_NaN.ToFloat()));
const BFloat16 csharp_NaN = BFloat16::FromBits(0xFFC1U);
EXPECT_TRUE(csharp_NaN.IsNaN());
EXPECT_TRUE(csharp_NaN.IsNaNOrZero());
EXPECT_NE(BFloat16::NaN, csharp_NaN);
EXPECT_TRUE(std::isnan(csharp_NaN.ToFloat()));
const BFloat16 fp16NANFromSingle = std::numeric_limits<float>::quiet_NaN();
EXPECT_TRUE(fp16NANFromSingle.IsNaN());
EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero());
@ -695,6 +739,8 @@ TEST_F(DataTypeTest, BFloat16NormalSubnormal) {
EXPECT_TRUE(smallest_subnormal.IsSubnormal());
EXPECT_FALSE(smallest_subnormal.IsNormal());
EXPECT_EQ(smallest_subnormal, std::numeric_limits<BFloat16>::denorm_min());
const float float_from_smallest_subnormal = (float)smallest_subnormal;
EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal));

View file

@ -137,6 +137,45 @@ TEST(MathOpTest, Clip_MLFloat16) {
test.Run();
}
// Clip (opset 12) on MLFloat16 with neither `min` nor `max` supplied:
// the expected output equals the input element for element.
TEST(MathOpTest, Clip_MLFloat16_NoMin_NoMax) {
  OpTester tester("Clip", 12);
  std::vector<int64_t> shape{3};

  tester.AddInput<MLFloat16>(
      "X", shape,
      {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)});

  tester.AddOutput<MLFloat16>(
      "Y", shape,
      {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)});

  tester.Run();
}
// Clip (opset 12) on MLFloat16 with only `min` = 0 supplied:
// negative inputs are expected to become 0 and the positive input is expected unchanged.
TEST(MathOpTest, Clip_MLFloat16_NoMax) {
  OpTester tester("Clip", 12);
  std::vector<int64_t> shape{3};

  tester.AddInput<MLFloat16>(
      "X", shape,
      {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)});
  // Scalar lower bound; no `max` input is added.
  tester.AddInput<MLFloat16>("min", {}, {MLFloat16(0.0f)});

  tester.AddOutput<MLFloat16>(
      "Y", shape,
      {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(3.0f)});

  tester.Run();
}
// Clip (opset 12) on MLFloat16 with only `max` = 0 supplied:
// negative inputs are expected unchanged and the positive input is expected to become 0.
TEST(MathOpTest, Clip_MLFloat16_NoMin) {
  OpTester tester("Clip", 12);
  std::vector<int64_t> shape{3};

  tester.AddInput<MLFloat16>(
      "X", shape,
      {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)});
  // Placeholder for the omitted `min` input so that `max` is added after it.
  tester.AddOptionalInputEdge<MLFloat16>();
  // Scalar upper bound.
  tester.AddInput<MLFloat16>("max", {}, {MLFloat16(0.0f)});

  tester.AddOutput<MLFloat16>(
      "Y", shape,
      {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(0.0f)});

  tester.Run();
}
TEST(MathOpTest, Clip_int32) {
OpTester test("Clip", 12);