mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Fix float 8 rounding on CPU (#16940)
### Description
Fix float 8 rounding issues discovered in issue #16938 (CPU provider only).
This commit is contained in:
parent
0a3eb60b01
commit
024f1dd72b
5 changed files with 187 additions and 72 deletions
|
|
@ -44,35 +44,36 @@ struct Float8E4M3FN {
|
|||
std::memcpy(&b, &v, sizeof(b));
|
||||
|
||||
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
|
||||
if ((b & 0x7fc00000) == 0x7fc00000) {
|
||||
val |= 0x7f;
|
||||
} else if ((b & 0x7fffffff) == 0x7f800000) {
|
||||
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
|
||||
if (saturate) {
|
||||
val |= 126;
|
||||
} else {
|
||||
val |= 0x7f;
|
||||
}
|
||||
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
|
||||
val |= 0x7f;
|
||||
} else {
|
||||
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
|
||||
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
|
||||
if (e != 0) {
|
||||
if (e < 117) { // 0b1110101
|
||||
} else if (e < 118) { // 0b1110110
|
||||
val |= 1;
|
||||
if ((m >> 23) & 1) {
|
||||
if (e < 117) {
|
||||
} else if (e < 121) {
|
||||
// denormalized number
|
||||
auto d = 120 - e;
|
||||
if (d < 3) {
|
||||
val |= 1 << (2 - d);
|
||||
val |= m >> (21 + d);
|
||||
} else if (m > 0) {
|
||||
val |= 1;
|
||||
}
|
||||
auto mask = 1 << (20 + d);
|
||||
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 121) { // 127 - 7 + 1 // 0b1111001
|
||||
auto d = 120 - e; // 0b1111000
|
||||
val |= 1 << (2 - d);
|
||||
val |= m >> (21 + d);
|
||||
if ((m >> (20 + d)) & 1) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 136) { // 127 + 8 + 1 // 0b10001000
|
||||
auto ex = e - 120; // 127 - 7
|
||||
} else if (e < 136) {
|
||||
// normalized number
|
||||
auto ex = e - 120;
|
||||
if (ex == 0) {
|
||||
val |= 0x4;
|
||||
val |= m >> 21;
|
||||
|
|
@ -83,7 +84,7 @@ struct Float8E4M3FN {
|
|||
val &= 0xFE;
|
||||
}
|
||||
}
|
||||
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) {
|
||||
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
|
||||
if ((val & 0x7F) < 0x7E) {
|
||||
// rounding
|
||||
val += 1;
|
||||
|
|
@ -205,36 +206,37 @@ struct Float8E4M3FNUZ {
|
|||
std::memcpy(&b, &v, sizeof(b));
|
||||
|
||||
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
|
||||
if ((b & 0x7fc00000) == 0x7fc00000) {
|
||||
val = 0x80;
|
||||
} else if ((b & 0x7fffffff) == 0x7f800000) {
|
||||
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
|
||||
if (saturate) {
|
||||
val |= 0x7F;
|
||||
} else {
|
||||
// infinity
|
||||
val = 0x80;
|
||||
}
|
||||
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
|
||||
val = 0x80;
|
||||
} else {
|
||||
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
|
||||
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
|
||||
if (e != 0) {
|
||||
if (e < 116) {
|
||||
} else if (e < 117) {
|
||||
val |= 1;
|
||||
if ((m >> 23) & 1) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 120) { // 127 - 8 + 1
|
||||
} else if (e < 120) {
|
||||
// denormalized number
|
||||
auto d = 119 - e;
|
||||
val |= 1 << (2 - d);
|
||||
val |= m >> (21 + d);
|
||||
if ((m >> (20 + d)) & 1) {
|
||||
if (d < 3) {
|
||||
val |= 1 << (2 - d);
|
||||
val |= m >> (21 + d);
|
||||
} else if (m > 0) {
|
||||
val |= 1;
|
||||
}
|
||||
auto mask = 1 << (20 + d);
|
||||
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 135) { // 127 + 8
|
||||
auto ex = e - 119; // 127 - 7
|
||||
} else if (e < 135) {
|
||||
// normalized number
|
||||
auto ex = e - 119;
|
||||
if (ex == 0) {
|
||||
val |= 0x4;
|
||||
val |= m >> 21;
|
||||
|
|
@ -242,7 +244,7 @@ struct Float8E4M3FNUZ {
|
|||
val |= ex << 3;
|
||||
val |= m >> 20;
|
||||
}
|
||||
if (m & 0x80000) {
|
||||
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
|
||||
if ((val & 0x7F) < 0x7F) {
|
||||
// rounding
|
||||
val += 1;
|
||||
|
|
@ -357,32 +359,32 @@ struct Float8E5M2 {
|
|||
uint32_t b;
|
||||
std::memcpy(&b, &v, sizeof(b));
|
||||
|
||||
val = (b & 0x80000000) >> 24; // sign
|
||||
if ((b & 0x7fc00000) == 0x7fc00000) {
|
||||
val |= 0x7f;
|
||||
} else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
|
||||
val = (b & 0x80000000) >> 24; // sign
|
||||
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
|
||||
if (saturate) {
|
||||
val |= 0x7B;
|
||||
} else {
|
||||
val |= 0x7C;
|
||||
}
|
||||
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
|
||||
val |= 0x7f;
|
||||
} else {
|
||||
uint32_t e = (b & 0x7F800000) >> 23; // exponent
|
||||
uint32_t m = b & 0x007FFFFF; // mantissa
|
||||
|
||||
if (e != 0) {
|
||||
if (e < 110) {
|
||||
} else if (e < 111) {
|
||||
val |= 1;
|
||||
if ((m >> 23) & 1) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 113) { // 127 - 15 + 1
|
||||
} else if (e < 113) {
|
||||
// denormalized number
|
||||
auto d = 112 - e;
|
||||
val |= 1 << (1 - d);
|
||||
val |= m >> (22 + d);
|
||||
if ((m >> (21 + d)) & 1) {
|
||||
if (d < 2) {
|
||||
val |= 1 << (1 - d);
|
||||
val |= m >> (22 + d);
|
||||
} else if (m > 0) {
|
||||
val |= 1;
|
||||
}
|
||||
auto mask = 1 << (21 + d);
|
||||
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
|
|
@ -513,40 +515,41 @@ struct Float8E5M2FNUZ {
|
|||
uint32_t b;
|
||||
std::memcpy(&b, &v, sizeof(b));
|
||||
|
||||
val = (b & 0x80000000) >> 24; // sign
|
||||
if ((b & 0x7fc00000) == 0x7fc00000) {
|
||||
val = 0x80;
|
||||
} else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
|
||||
val = (b & 0x80000000) >> 24; // sign
|
||||
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
|
||||
if (saturate) {
|
||||
val |= 0x7F;
|
||||
} else {
|
||||
val = 0x80;
|
||||
}
|
||||
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
|
||||
val = 0x80;
|
||||
} else {
|
||||
uint32_t e = (b & 0x7F800000) >> 23; // exponent
|
||||
uint32_t m = b & 0x007FFFFF; // mantissa
|
||||
|
||||
if (e != 0) {
|
||||
if (e < 109) {
|
||||
} else if (e < 110) {
|
||||
val |= 1;
|
||||
if ((m >> 23) & 1) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 112) { // 127 - 16 + 1
|
||||
} else if (e < 112) {
|
||||
// denormalized number
|
||||
auto d = 111 - e;
|
||||
val |= 1 << (1 - d);
|
||||
val |= m >> (22 + d);
|
||||
if ((m >> (21 + d)) & 1) {
|
||||
if (d < 2) {
|
||||
val |= 1 << (1 - d);
|
||||
val |= m >> (22 + d);
|
||||
} else if (m > 0) {
|
||||
val |= 1;
|
||||
}
|
||||
auto mask = 1 << (21 + d);
|
||||
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
|
||||
// rounding
|
||||
val += 1;
|
||||
}
|
||||
} else if (e < 143) { // 127 + 15 + 1
|
||||
} else if (e < 143) {
|
||||
// normalized number
|
||||
auto ex = e - 111;
|
||||
val |= ex << 2;
|
||||
val |= m >> 21;
|
||||
if (m & 0x100000) {
|
||||
if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
|
||||
if ((val & 0x7F) < 0x7F) {
|
||||
// rounding
|
||||
val += 1;
|
||||
|
|
@ -554,7 +557,7 @@ struct Float8E5M2FNUZ {
|
|||
val = 0x80;
|
||||
}
|
||||
}
|
||||
} else if ((e == 255) && (m == 0)) { // inf
|
||||
} else if ((e == 255) && (m == 0)) {
|
||||
val = 0x80;
|
||||
} else if (saturate) {
|
||||
val |= 0x7F;
|
||||
|
|
|
|||
96
onnxruntime/test/framework/float_8_test.cc
Normal file
96
onnxruntime/test/framework/float_8_test.cc
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "core/framework/float8.h"
|
||||
#include "test/capturing_sink.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Round-trip rounding test for Float8E4M3FN: each float32 input is converted
// to Float8E4M3FN and back to float32, and the result must equal the expected
// value exactly.
TEST(Float8_Tests, CastE4M3FN) {
|
||||
// Each pair is (float32 input, expected float32 after the float8 round trip).
// All but the last entry are small magnitudes (roughly 1e-3 .. 1e-2); the last
// entry (1.8131605 -> 1.875) is a value of ordinary magnitude.
std::vector<std::pair<float, float>> cases{
|
||||
std::pair<float, float>(0.00439453125, 0.00390625),
|
||||
std::pair<float, float>(0.005859375, 0.005859375),
|
||||
std::pair<float, float>(0.005759375, 0.005859375),
|
||||
std::pair<float, float>(0.0046875, 0.00390625),
|
||||
std::pair<float, float>(0.001953125, 0.001953125),
|
||||
std::pair<float, float>(0.0029296875, 0.00390625),
|
||||
std::pair<float, float>(0.002053125, 0.001953125),
|
||||
std::pair<float, float>(0.00234375, 0.001953125),
|
||||
std::pair<float, float>(0.0087890625, 0.0078125),
|
||||
std::pair<float, float>(0.001171875, 0.001953125),
|
||||
std::pair<float, float>(1.8131605, 1.875)};
|
||||
for (auto it : cases) {
|
||||
// Construct with the default arguments (no explicit second argument).
auto f8 = onnxruntime::Float8E4M3FN(it.first);
|
||||
auto f8_32 = f8.ToFloat();
|
||||
// Exact equality is intended: the expected value is compared bit-for-bit
// as a float, not within a tolerance.
EXPECT_EQ(it.second, f8_32);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper used by the tests below to build a float from a raw 32-bit pattern:
// the tests brace-initialise it with an integer (which sets `bits`, the first
// member) and then read `val`.
// NOTE(review): reading `val` after initialising `bits` relies on union type
// punning, which ISO C++ does not define (compilers commonly support it as an
// extension); the conversion code itself uses std::memcpy for the same purpose
// -- consider confirming this is acceptable for all target compilers.
union float_bits {
|
||||
// Raw 32-bit pattern of the float.
uint32_t bits;
|
||||
// The same storage viewed as a float.
float val;
|
||||
};
|
||||
|
||||
// Checks the raw byte (`.val`) produced by Float8E4M3FN for float32 infinity
// and NaN bit patterns, for both signs.
TEST(Float8_Tests, NanE4M3FN) {
|
||||
// +/- infinity (0x7F800000 / 0xFF800000) with default arguments: expected
// 0x7E with the sign bit preserved.
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7E));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFE));
|
||||
// Same infinities with the second constructor argument set to false
// (presumably the `saturate` flag -- confirm against float8.h): expected 0x7F
// with the sign bit preserved.
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0xFF));
|
||||
// NaN patterns with all exponent bits set and only the lowest mantissa bit
// set: expected 0x7F / 0xFF.
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0xFF));
|
||||
// 0x7FC00000 is the value used by numpy.
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0xFF));
|
||||
}
|
||||
|
||||
// Checks the raw byte (`.val`) produced by Float8E4M3FNUZ for float32 infinity
// and NaN bit patterns, for both signs.
TEST(Float8_Tests, NanE4M3FNUZ) {
|
||||
// +/- infinity with default arguments: expected 0x7F with the sign bit
// preserved.
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFF));
|
||||
// Same infinities with the second constructor argument set to false
// (presumably the `saturate` flag -- confirm against float8.h): expected 0x80
// regardless of the input sign.
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x80));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0x80));
|
||||
// NaN patterns (lowest mantissa bit set): expected 0x80 for both signs.
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x80));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0x80));
|
||||
// 0x7FC00000 is the value used by numpy.
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x80));
|
||||
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0x80));
|
||||
}
|
||||
|
||||
// Checks the raw byte (`.val`) produced by Float8E5M2 for float32 infinity and
// NaN bit patterns, for both signs.
TEST(Float8_Tests, NanE5M2) {
|
||||
// +/- infinity with default arguments: expected 0x7B with the sign bit
// preserved.
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7B));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFB));
|
||||
// Same infinities with the second constructor argument set to false
// (presumably the `saturate` flag -- confirm against float8.h): expected 0x7C
// with the sign bit preserved.
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x7C));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0xFC));
|
||||
// NaN patterns (lowest mantissa bit set): expected 0x7F / 0xFF.
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0xFF));
|
||||
// 0x7FC00000 is the value used by numpy.
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0xFF));
|
||||
}
|
||||
|
||||
// Checks the raw byte (`.val`) produced by Float8E5M2FNUZ for float32 infinity
// and NaN bit patterns, for both signs.
TEST(Float8_Tests, NanE5M2FNUZ) {
|
||||
// +/- infinity with default arguments: expected 0x7F with the sign bit
// preserved.
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val).val, static_cast<uint8_t>(0x7F));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val).val, static_cast<uint8_t>(0xFF));
|
||||
// Same infinities with the second constructor argument set to false
// (presumably the `saturate` flag -- confirm against float8.h): expected 0x80
// regardless of the input sign.
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val, false).val, static_cast<uint8_t>(0x80));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val, false).val, static_cast<uint8_t>(0x80));
|
||||
// NaN patterns (lowest mantissa bit set): expected 0x80 for both signs.
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800001}).val).val, static_cast<uint8_t>(0x80));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800001}).val).val, static_cast<uint8_t>(0x80));
|
||||
// 0x7FC00000 is the value used by numpy.
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x80));
|
||||
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0x80));
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // DISABLE_FLOAT8_TYPES
|
||||
|
|
@ -8,9 +8,11 @@ import sys
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import packaging.version as pv
|
||||
import parameterized
|
||||
from numpy.testing import assert_allclose
|
||||
from onnx import TensorProto
|
||||
from onnx import __version__ as onnx_version
|
||||
from onnx.checker import check_model
|
||||
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor, make_tensor_value_info
|
||||
from onnx.reference import ReferenceEvaluator
|
||||
|
|
@ -37,7 +39,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
<https://onnx.ai/onnx/api/numpy_helper.html#onnx.numpy_helper.float8e5m2_to_float32>`_.
|
||||
"""
|
||||
|
||||
dtypes = frozenset({"FLOAT": np.float32, "FLOAT16": np.float16})
|
||||
dtypes = {"FLOAT": np.float32, "FLOAT16": np.float16} # noqa: RUF012
|
||||
x = np.array(
|
||||
[0.4068359375, 352, 416, 336, 304, 272, -248, -100, 1e-4, 1e-2, 416, 432, 1e5, np.inf, -np.inf, np.nan],
|
||||
dtype=np.float32,
|
||||
|
|
@ -76,7 +78,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
240.0,
|
||||
240.0,
|
||||
-240.0,
|
||||
-104.0,
|
||||
-96.0,
|
||||
0.0,
|
||||
0.009765625,
|
||||
240.0,
|
||||
|
|
@ -113,7 +115,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
[
|
||||
0.4375,
|
||||
384.0,
|
||||
448.0,
|
||||
384.0,
|
||||
320.0,
|
||||
320.0,
|
||||
256.0,
|
||||
|
|
@ -121,7 +123,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
-96.0,
|
||||
0.0001068115234375,
|
||||
0.009765625,
|
||||
448.0,
|
||||
384.0,
|
||||
448.0,
|
||||
57344.0,
|
||||
57344.0,
|
||||
|
|
@ -167,7 +169,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
-104.0,
|
||||
-96.0,
|
||||
0.0,
|
||||
0.009765625,
|
||||
np.nan,
|
||||
|
|
@ -204,7 +206,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
[
|
||||
0.4375,
|
||||
384.0,
|
||||
448.0,
|
||||
384.0,
|
||||
320.0,
|
||||
320.0,
|
||||
256.0,
|
||||
|
|
@ -212,7 +214,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
-96.0,
|
||||
0.0001068115234375,
|
||||
0.009765625,
|
||||
448.0,
|
||||
384.0,
|
||||
448.0,
|
||||
np.nan,
|
||||
np.nan,
|
||||
|
|
@ -245,6 +247,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
check_model(onnx_model)
|
||||
return onnx_model
|
||||
|
||||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0")
|
||||
@parameterized.parameterized.expand(
|
||||
[
|
||||
("FLOAT8E4M3FN", "FLOAT", 1),
|
||||
|
|
@ -429,6 +432,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
check_model(onnx_model)
|
||||
return onnx_model
|
||||
|
||||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0")
|
||||
@parameterized.parameterized.expand(
|
||||
[
|
||||
("FLOAT8E4M3FN", "FLOAT", 1),
|
||||
|
|
@ -689,6 +693,18 @@ class TestInferenceSession(unittest.TestCase):
|
|||
self.assertEqual(expect.shape, y.shape)
|
||||
self.assertEqual(expect.dtype, y.dtype)
|
||||
|
||||
# Runs the same stored model on the CPU and CUDA execution providers and
# requires identical outputs. Skipped when CUDAExecutionProvider is not in
# available_providers.
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.")
|
||||
def test_compare_cpu_cuda_e4m3fn(self):
|
||||
# Test assets live in ../testdata/float8 relative to this file.
folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8")
|
||||
model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx")
|
||||
data = np.load(os.path.join(folder, "te.cast_fp8_1_fp32_input.npy"))
|
||||
# One session per provider, both loading the same model file.
sess_cpu = onnxruntime.InferenceSession(model, providers=["CPUExecutionProvider"])
|
||||
sess_cuda = onnxruntime.InferenceSession(model, providers=["CUDAExecutionProvider"])
|
||||
# Feed the same array to the "input" feed and keep the first output of each run.
cpu_res = sess_cpu.run(None, {"input": data})[0]
|
||||
cuda_res = sess_cuda.run(None, {"input": data})[0]
|
||||
# Exact comparison of the two results as Python lists (no tolerance).
self.assertEqual(cuda_res.tolist(), cpu_res.tolist())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy
vendored
Normal file
BIN
onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue