diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 0a318dac17..0fd04f28d4 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -44,35 +44,36 @@ struct Float8E4M3FN { std::memcpy(&b, &v, sizeof(b)); val = static_cast((b & 0x80000000) >> 24); // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val |= 0x7f; - } else if ((b & 0x7fffffff) == 0x7f800000) { + if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { val |= 126; } else { val |= 0x7f; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa if (e != 0) { - if (e < 117) { // 0b1110101 - } else if (e < 118) { // 0b1110110 - val |= 1; - if ((m >> 23) & 1) { + if (e < 117) { + } else if (e < 121) { + // denormalized number + auto d = 120 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 121) { // 127 - 7 + 1 // 0b1111001 - auto d = 120 - e; // 0b1111000 - val |= 1 << (2 - d); - val |= m >> (21 + d); - if ((m >> (20 + d)) & 1) { - // rounding - val += 1; - } - } else if (e < 136) { // 127 + 8 + 1 // 0b10001000 - auto ex = e - 120; // 127 - 7 + } else if (e < 136) { + // normalized number + auto ex = e - 120; if (ex == 0) { val |= 0x4; val |= m >> 21; @@ -83,7 +84,7 @@ struct Float8E4M3FN { val &= 0xFE; } } - if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) { + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { if ((val & 0x7F) < 0x7E) { // rounding val += 1; @@ -205,36 +206,37 @@ struct Float8E4M3FNUZ { std::memcpy(&b, &v, sizeof(b)); val = static_cast((b & 0x80000000) >> 24); // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val = 0x80; - } else if ((b & 0x7fffffff) == 0x7f800000) { + if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { val |= 0x7F; } else { // infinity val = 0x80; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa if (e != 0) { if (e < 116) { - } else if (e < 117) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 120) { // 127 - 8 + 1 + } else if (e < 120) { + // denormalized number auto d = 119 - e; - val |= 1 << (2 - d); - val |= m >> (21 + d); - if ((m >> (20 + d)) & 1) { + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 135) { // 127 + 8 - auto ex = e - 119; // 127 - 7 + } else if (e < 135) { + // normalized number + auto ex = e - 119; if (ex == 0) { val |= 0x4; val |= m >> 21; @@ -242,7 +244,7 @@ struct Float8E4M3FNUZ { val |= ex << 3; val |= m >> 20; } - if (m & 0x80000) { + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { if ((val & 0x7F) < 0x7F) { // rounding val += 1; @@ -357,32 +359,32 @@ struct Float8E5M2 { uint32_t b; std::memcpy(&b, &v, sizeof(b)); - val = (b & 0x80000000) >> 24; // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val |= 0x7f; - } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { val |= 0x7B; } else { val |= 0x7C; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; } else { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa if (e != 0) { if (e < 110) { - } else if (e < 111) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 113) { // 127 - 15 + 1 + } else if (e < 113) { + // denormalized number auto d = 112 - e; - val |= 1 << (1 - d); - val |= m >> (22 + d); - if ((m >> (21 + d)) & 1) { + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } @@ -513,40 +515,41 @@ struct Float8E5M2FNUZ { uint32_t b; std::memcpy(&b, &v, sizeof(b)); - val = (b & 0x80000000) >> 24; // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val = 0x80; - } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { val |= 0x7F; } else { val = 0x80; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; } else { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa if (e != 0) { if (e < 109) { - } else if (e < 110) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 112) { // 127 - 16 + 1 + } else if (e < 112) { + // denormalized number auto d = 111 - e; - val |= 1 << (1 - d); - val |= m >> (22 + d); - if ((m >> (21 + d)) & 1) { + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 143) { // 127 + 15 + 1 + } else if (e < 143) { + // normalized number auto ex = e - 111; val |= ex << 2; val |= m >> 21; - if (m & 0x100000) { + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { if ((val & 0x7F) < 0x7F) { // rounding val += 1; @@ -554,7 +557,7 @@ struct Float8E5M2FNUZ { val = 0x80; } } - } else if ((e == 255) && (m == 0)) { // inf + } else if ((e == 255) && (m == 0)) { val = 0x80; } else if (saturate) { val |= 0x7F; diff --git a/onnxruntime/test/framework/float_8_test.cc b/onnxruntime/test/framework/float_8_test.cc new file mode 100644 index 0000000000..948e0e05a9 --- /dev/null +++ b/onnxruntime/test/framework/float_8_test.cc @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include + +#include "core/framework/float8.h" +#include "test/capturing_sink.h" +#include "test/test_environment.h" +#include "test_utils.h" +#include "gtest/gtest.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace test { + +TEST(Float8_Tests, CastE4M3FN) { + std::vector> cases{ + std::pair(0.00439453125, 0.00390625), + std::pair(0.005859375, 0.005859375), + std::pair(0.005759375, 0.005859375), + std::pair(0.0046875, 0.00390625), + std::pair(0.001953125, 0.001953125), + std::pair(0.0029296875, 0.00390625), + std::pair(0.002053125, 0.001953125), + std::pair(0.00234375, 0.001953125), + std::pair(0.0087890625, 0.0078125), + std::pair(0.001171875, 0.001953125), + std::pair(1.8131605, 1.875)}; + for (auto it : cases) { + auto f8 = onnxruntime::Float8E4M3FN(it.first); + auto f8_32 = f8.ToFloat(); + EXPECT_EQ(it.second, f8_32); + } +} + +union float_bits { + uint32_t bits; + float val; +}; + +TEST(Float8_Tests, NanE4M3FN) { + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val).val, static_cast(0x7E)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val).val, static_cast(0xFE)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val, false).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val, false).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800001}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800001}).val).val, static_cast(0xFF)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); +} + +TEST(Float8_Tests, NanE4M3FNUZ) { + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800001}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800001}).val).val, static_cast(0x80)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); +} + +TEST(Float8_Tests, NanE5M2) { + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val).val, static_cast(0x7B)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val).val, static_cast(0xFB)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val, false).val, static_cast(0x7C)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val, false).val, static_cast(0xFC)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800001}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800001}).val).val, static_cast(0xFF)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); +} + +TEST(Float8_Tests, NanE5M2FNUZ) { + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800001}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800001}).val).val, static_cast(0x80)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // DISABLE_FLOAT8_TYPES diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py index 3f3180230f..76ca5d9538 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8.py @@ -8,9 +8,11 @@ import sys import unittest import numpy as np +import packaging.version as pv import parameterized from numpy.testing import assert_allclose from onnx import TensorProto +from onnx import __version__ as onnx_version from onnx.checker import check_model from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor, make_tensor_value_info from onnx.reference import ReferenceEvaluator @@ -37,7 +39,7 @@ class TestInferenceSession(unittest.TestCase): `_. """ - dtypes = frozenset({"FLOAT": np.float32, "FLOAT16": np.float16}) + dtypes = {"FLOAT": np.float32, "FLOAT16": np.float16} # noqa: RUF012 x = np.array( [0.4068359375, 352, 416, 336, 304, 272, -248, -100, 1e-4, 1e-2, 416, 432, 1e5, np.inf, -np.inf, np.nan], dtype=np.float32, @@ -76,7 +78,7 @@ class TestInferenceSession(unittest.TestCase): 240.0, 240.0, -240.0, - -104.0, + -96.0, 0.0, 0.009765625, 240.0, @@ -113,7 +115,7 @@ class TestInferenceSession(unittest.TestCase): [ 0.4375, 384.0, - 448.0, + 384.0, 320.0, 320.0, 256.0, @@ -121,7 +123,7 @@ class TestInferenceSession(unittest.TestCase): -96.0, 0.0001068115234375, 0.009765625, - 448.0, + 384.0, 448.0, 57344.0, 57344.0, @@ -167,7 +169,7 @@ class TestInferenceSession(unittest.TestCase): np.nan, np.nan, np.nan, - -104.0, + -96.0, 0.0, 0.009765625, np.nan, @@ -204,7 +206,7 @@ class TestInferenceSession(unittest.TestCase): [ 0.4375, 384.0, - 448.0, + 384.0, 320.0, 320.0, 256.0, @@ -212,7 +214,7 @@ class TestInferenceSession(unittest.TestCase): -96.0, 0.0001068115234375, 0.009765625, - 448.0, + 384.0, 448.0, np.nan, np.nan, @@ -245,6 +247,7 @@ class TestInferenceSession(unittest.TestCase): check_model(onnx_model) return onnx_model + @unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0") @parameterized.parameterized.expand( [ ("FLOAT8E4M3FN", "FLOAT", 1), @@ -429,6 +432,7 @@ class TestInferenceSession(unittest.TestCase): check_model(onnx_model) return onnx_model + @unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0") @parameterized.parameterized.expand( [ ("FLOAT8E4M3FN", "FLOAT", 1), @@ -689,6 +693,18 @@ class TestInferenceSession(unittest.TestCase): self.assertEqual(expect.shape, y.shape) self.assertEqual(expect.dtype, y.dtype) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + def test_compare_cpu_cuda_e4m3fn(self): + folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8") + model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx") + data = np.load(os.path.join(folder, "te.cast_fp8_1_fp32_input.npy")) + + sess_cpu = onnxruntime.InferenceSession(model, providers=["CPUExecutionProvider"]) + sess_cuda = onnxruntime.InferenceSession(model, providers=["CUDAExecutionProvider"]) + cpu_res = sess_cpu.run(None, {"input": data})[0] + cuda_res = sess_cuda.run(None, {"input": data})[0] + self.assertEqual(cuda_res.tolist(), cpu_res.tolist()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx new file mode 100644 index 0000000000..1dec991008 Binary files /dev/null and b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx differ diff --git a/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy new file mode 100644 index 0000000000..706f508836 Binary files /dev/null and b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy differ