Fix float 8 rounding on CPU (#16940)

### Description
Fix float 8 rounding issues discovered in issue #16938 (only CPU
provider).
This commit is contained in:
Xavier Dupré 2023-09-07 20:48:25 +02:00 committed by GitHub
parent 0a3eb60b01
commit 024f1dd72b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 187 additions and 72 deletions

View file

@ -44,35 +44,36 @@ struct Float8E4M3FN {
std::memcpy(&b, &v, sizeof(b));
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val |= 0x7f;
} else if ((b & 0x7fffffff) == 0x7f800000) {
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
val |= 126;
} else {
val |= 0x7f;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val |= 0x7f;
} else {
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
if (e != 0) {
if (e < 117) { // 0b1110101
} else if (e < 118) { // 0b1110110
val |= 1;
if ((m >> 23) & 1) {
if (e < 117) {
} else if (e < 121) {
// denormalized number
auto d = 120 - e;
if (d < 3) {
val |= 1 << (2 - d);
val |= m >> (21 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (20 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 121) { // 127 - 7 + 1 // 0b1111001
auto d = 120 - e; // 0b1111000
val |= 1 << (2 - d);
val |= m >> (21 + d);
if ((m >> (20 + d)) & 1) {
// rounding
val += 1;
}
} else if (e < 136) { // 127 + 8 + 1 // 0b10001000
auto ex = e - 120; // 127 - 7
} else if (e < 136) {
// normalized number
auto ex = e - 120;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
@ -83,7 +84,7 @@ struct Float8E4M3FN {
val &= 0xFE;
}
}
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) {
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7E) {
// rounding
val += 1;
@ -205,36 +206,37 @@ struct Float8E4M3FNUZ {
std::memcpy(&b, &v, sizeof(b));
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val = 0x80;
} else if ((b & 0x7fffffff) == 0x7f800000) {
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
val |= 0x7F;
} else {
// infinity
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val = 0x80;
} else {
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
if (e != 0) {
if (e < 116) {
} else if (e < 117) {
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
}
} else if (e < 120) { // 127 - 8 + 1
} else if (e < 120) {
// denormalized number
auto d = 119 - e;
val |= 1 << (2 - d);
val |= m >> (21 + d);
if ((m >> (20 + d)) & 1) {
if (d < 3) {
val |= 1 << (2 - d);
val |= m >> (21 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (20 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 135) { // 127 + 8
auto ex = e - 119; // 127 - 7
} else if (e < 135) {
// normalized number
auto ex = e - 119;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
@ -242,7 +244,7 @@ struct Float8E4M3FNUZ {
val |= ex << 3;
val |= m >> 20;
}
if (m & 0x80000) {
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
@ -357,32 +359,32 @@ struct Float8E5M2 {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val |= 0x7f;
} else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
val |= 0x7B;
} else {
val |= 0x7C;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val |= 0x7f;
} else {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa
if (e != 0) {
if (e < 110) {
} else if (e < 111) {
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
}
} else if (e < 113) { // 127 - 15 + 1
} else if (e < 113) {
// denormalized number
auto d = 112 - e;
val |= 1 << (1 - d);
val |= m >> (22 + d);
if ((m >> (21 + d)) & 1) {
if (d < 2) {
val |= 1 << (1 - d);
val |= m >> (22 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (21 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
@ -513,40 +515,41 @@ struct Float8E5M2FNUZ {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val = 0x80;
} else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val = 0x80;
} else {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa
if (e != 0) {
if (e < 109) {
} else if (e < 110) {
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
}
} else if (e < 112) { // 127 - 16 + 1
} else if (e < 112) {
// denormalized number
auto d = 111 - e;
val |= 1 << (1 - d);
val |= m >> (22 + d);
if ((m >> (21 + d)) & 1) {
if (d < 2) {
val |= 1 << (1 - d);
val |= m >> (22 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (21 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 143) { // 127 + 15 + 1
} else if (e < 143) {
// normalized number
auto ex = e - 111;
val |= ex << 2;
val |= m >> 21;
if (m & 0x100000) {
if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
@ -554,7 +557,7 @@ struct Float8E5M2FNUZ {
val = 0x80;
}
}
} else if ((e == 255) && (m == 0)) { // inf
} else if ((e == 255) && (m == 0)) {
val = 0x80;
} else if (saturate) {
val |= 0x7F;

View file

@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if !defined(DISABLE_FLOAT8_TYPES)
#include <vector>
#include "core/framework/float8.h"
#include "test/capturing_sink.h"
#include "test/test_environment.h"
#include "test_utils.h"
#include "gtest/gtest.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace test {
TEST(Float8_Tests, CastE4M3FN) {
std::vector<std::pair<float, float>> cases{
std::pair<float, float>(0.00439453125, 0.00390625),
std::pair<float, float>(0.005859375, 0.005859375),
std::pair<float, float>(0.005759375, 0.005859375),
std::pair<float, float>(0.0046875, 0.00390625),
std::pair<float, float>(0.001953125, 0.001953125),
std::pair<float, float>(0.0029296875, 0.00390625),
std::pair<float, float>(0.002053125, 0.001953125),
std::pair<float, float>(0.00234375, 0.001953125),
std::pair<float, float>(0.0087890625, 0.0078125),
std::pair<float, float>(0.001171875, 0.001953125),
std::pair<float, float>(1.8131605, 1.875)};
for (auto it : cases) {
auto f8 = onnxruntime::Float8E4M3FN(it.first);
auto f8_32 = f8.ToFloat();
EXPECT_EQ(it.second, f8_32);
}
}
union float_bits {
uint32_t bits;
float val;
};
TEST(Float8_Tests, NanE4M3FN) {
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7E));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFE));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0xFF));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0xFF));
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0xFF));
}
TEST(Float8_Tests, NanE4M3FNUZ) {
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFF));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0x80));
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0x80));
}
TEST(Float8_Tests, NanE5M2) {
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7B));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFB));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x7C));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0xFC));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0xFF));
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0xFF));
}
TEST(Float8_Tests, NanE5M2FNUZ) {
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFF));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0x80));
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0x80));
}
} // namespace test
} // namespace onnxruntime
#endif // DISABLE_FLOAT8_TYPES

View file

@ -8,9 +8,11 @@ import sys
import unittest
import numpy as np
import packaging.version as pv
import parameterized
from numpy.testing import assert_allclose
from onnx import TensorProto
from onnx import __version__ as onnx_version
from onnx.checker import check_model
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor, make_tensor_value_info
from onnx.reference import ReferenceEvaluator
@ -37,7 +39,7 @@ class TestInferenceSession(unittest.TestCase):
<https://onnx.ai/onnx/api/numpy_helper.html#onnx.numpy_helper.float8e5m2_to_float32>`_.
"""
dtypes = frozenset({"FLOAT": np.float32, "FLOAT16": np.float16})
dtypes = {"FLOAT": np.float32, "FLOAT16": np.float16} # noqa: RUF012
x = np.array(
[0.4068359375, 352, 416, 336, 304, 272, -248, -100, 1e-4, 1e-2, 416, 432, 1e5, np.inf, -np.inf, np.nan],
dtype=np.float32,
@ -76,7 +78,7 @@ class TestInferenceSession(unittest.TestCase):
240.0,
240.0,
-240.0,
-104.0,
-96.0,
0.0,
0.009765625,
240.0,
@ -113,7 +115,7 @@ class TestInferenceSession(unittest.TestCase):
[
0.4375,
384.0,
448.0,
384.0,
320.0,
320.0,
256.0,
@ -121,7 +123,7 @@ class TestInferenceSession(unittest.TestCase):
-96.0,
0.0001068115234375,
0.009765625,
448.0,
384.0,
448.0,
57344.0,
57344.0,
@ -167,7 +169,7 @@ class TestInferenceSession(unittest.TestCase):
np.nan,
np.nan,
np.nan,
-104.0,
-96.0,
0.0,
0.009765625,
np.nan,
@ -204,7 +206,7 @@ class TestInferenceSession(unittest.TestCase):
[
0.4375,
384.0,
448.0,
384.0,
320.0,
320.0,
256.0,
@ -212,7 +214,7 @@ class TestInferenceSession(unittest.TestCase):
-96.0,
0.0001068115234375,
0.009765625,
448.0,
384.0,
448.0,
np.nan,
np.nan,
@ -245,6 +247,7 @@ class TestInferenceSession(unittest.TestCase):
check_model(onnx_model)
return onnx_model
@unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0")
@parameterized.parameterized.expand(
[
("FLOAT8E4M3FN", "FLOAT", 1),
@ -429,6 +432,7 @@ class TestInferenceSession(unittest.TestCase):
check_model(onnx_model)
return onnx_model
@unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0")
@parameterized.parameterized.expand(
[
("FLOAT8E4M3FN", "FLOAT", 1),
@ -689,6 +693,18 @@ class TestInferenceSession(unittest.TestCase):
self.assertEqual(expect.shape, y.shape)
self.assertEqual(expect.dtype, y.dtype)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.")
def test_compare_cpu_cuda_e4m3fn(self):
folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8")
model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx")
data = np.load(os.path.join(folder, "te.cast_fp8_1_fp32_input.npy"))
sess_cpu = onnxruntime.InferenceSession(model, providers=["CPUExecutionProvider"])
sess_cuda = onnxruntime.InferenceSession(model, providers=["CUDAExecutionProvider"])
cpu_res = sess_cpu.run(None, {"input": data})[0]
cuda_res = sess_cuda.run(None, {"input": data})[0]
self.assertEqual(cuda_res.tolist(), cpu_res.tolist())
if __name__ == "__main__":
unittest.main(verbosity=2)

Binary file not shown.

Binary file not shown.