diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index dea71d81f8..ba610515ac 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -156,8 +156,10 @@ Do not modify directly.*
|||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)|
|InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)|
-|IsInf|*in* X:**T1**
*out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)|
-|IsNaN|*in* X:**T1**
*out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
+|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
+|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
+|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)|
|||[1, 12]|**T** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h
index 0fd04f28d4..dd607cbbc6 100644
--- a/include/onnxruntime/core/framework/float8.h
+++ b/include/onnxruntime/core/framework/float8.h
@@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
val = static_cast((b & 0x80000000) >> 24); // sign
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
+ // the highest available value
val |= 0x7F;
} else {
- // infinity
+ // NaN
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
@@ -362,8 +363,10 @@ struct Float8E5M2 {
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
+ // the highest available value
val |= 0x7B;
} else {
+ // the infinity
val |= 0x7C;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 3d03abf5b7..a54d999a10 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence);
@@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
@@ -960,6 +960,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
+#endif
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
@@ -1492,7 +1502,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Dropout)>,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+#if !defined(DISABLE_FLOAT8_TYPES)
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+#endif
+ BuildKernelCreateInfo,
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc
index bc99caa803..1b449f4692 100644
--- a/onnxruntime/core/providers/cpu/tensor/isinf.cc
+++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc
@@ -14,15 +14,38 @@ namespace onnxruntime {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf
namespace op_kernel_type_control {
-ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
- kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
- float, double);
+using IsInfTypesOpset10 = TypeList;
+
+ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
+ kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0,
+ IsInfTypesOpset10);
+
+using IsInfTypesOpset20 =
+ TypeList<
+ float,
+ double
+#if !defined(DISABLE_FLOAT8_TYPES)
+ ,
+ Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
+#endif
+ >;
+
+ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
+ kCpuExecutionProvider,
+ kOnnxDomain,
+ IsInf,
+ 20,
+ Input,
+ 0,
+ IsInfTypesOpset20);
} // namespace op_kernel_type_control
class IsInf final : public OpKernel {
public:
- using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
- IsInf, Input, 0);
+ using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
+ IsInf, 10, Input, 0);
+ using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
+ IsInf, 20, Input, 0);
explicit IsInf(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
@@ -30,14 +53,25 @@ class IsInf final : public OpKernel {
private:
int64_t detect_positive_{1};
int64_t detect_negative_{1};
+ int opset_;
};
+ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
+ IsInf,
+ 10,
+ 19,
+ KernelDefBuilder()
+ .TypeConstraint("T1",
+ BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ IsInf);
+
ONNX_CPU_OPERATOR_KERNEL(
IsInf,
- 10,
+ 20,
KernelDefBuilder()
.TypeConstraint("T1",
- BuildKernelDefConstraintsFromTypeList())
+ BuildKernelDefConstraintsFromTypeList())
.TypeConstraint("T2", DataTypeImpl::GetTensorType()),
IsInf);
@@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
status = info.GetAttr("detect_negative", &detect_negative_);
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
+ opset_ = info.node().SinceVersion();
}
namespace isinf_internal {
@@ -78,6 +113,49 @@ struct ComputeDispatchTarget {
}
}
};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+template <>
+struct ComputeDispatchTarget {
+ void operator()(const Tensor&, Tensor& Y, bool, bool) const {
+ EigenMap(Y).array() = false;
+ }
+};
+
+template <>
+struct ComputeDispatchTarget {
+ void operator()(const Tensor&, Tensor& Y, bool, bool) const {
+ EigenMap(Y).array() = false;
+ }
+};
+
+template <>
+struct ComputeDispatchTarget {
+ void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+ auto& dims = X.Shape();
+ auto input = ConstEigenVectorMap(static_cast(static_cast(X.Data())), onnxruntime::narrow(dims.Size()));
+ auto output = EigenMap(Y);
+
+ // S.11111.00
+ if (detect_positive && detect_negative) {
+ output.array() = input.array() == 0b01111100 || input.array() == 0b11111100;
+ } else if (detect_positive) {
+ output.array() = input.array() == 0b01111100;
+ } else if (detect_negative) {
+ output.array() = input.array() == 0b11111100;
+ } else {
+ output.array() = false;
+ }
+ }
+};
+
+template <>
+struct ComputeDispatchTarget {
+ void operator()(const Tensor&, Tensor& Y, bool, bool) const {
+ EigenMap(Y).array() = false;
+ }
+};
+#endif
} // namespace isinf_internal
Status IsInf::Compute(OpKernelContext* context) const {
@@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const {
using namespace isinf_internal;
- utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()};
- dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0);
+ if (opset_ < 20) {
+ utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()};
+ dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0);
+ } else {
+ utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()};
+ dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0);
+ }
return Status::OK();
}
diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc
index 33d0f8eb6c..34495e3822 100644
--- a/onnxruntime/core/providers/cpu/tensor/isnan.cc
+++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc
@@ -20,10 +20,20 @@ namespace onnxruntime {
.TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
IsNaN);
+#define ADD_TYPED_ISNAN_OP_13(data_type) \
+ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
+ IsNaN, \
+ 13, 19, \
+ data_type, \
+ KernelDefBuilder() \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ IsNaN);
+
#define ADD_TYPED_ISNAN_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
IsNaN, \
- 13, \
+ 20, \
data_type, \
KernelDefBuilder() \
.TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
@@ -33,10 +43,20 @@ namespace onnxruntime {
ADD_TYPED_ISNAN_OP_9(float);
ADD_TYPED_ISNAN_OP_9(double);
ADD_TYPED_ISNAN_OP_9(MLFloat16);
+ADD_TYPED_ISNAN_OP_13(float);
+ADD_TYPED_ISNAN_OP_13(double);
+ADD_TYPED_ISNAN_OP_13(MLFloat16);
ADD_TYPED_ISNAN_OP(float);
ADD_TYPED_ISNAN_OP(double);
ADD_TYPED_ISNAN_OP(MLFloat16);
+#if !defined(DISABLE_FLOAT8_TYPES)
+ADD_TYPED_ISNAN_OP(Float8E4M3FN);
+ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ);
+ADD_TYPED_ISNAN_OP(Float8E5M2);
+ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ);
+#endif
+
template
Status IsNaN::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input(0);
@@ -70,4 +90,63 @@ Status IsNaN::Compute(OpKernelContext* context) const {
return Status::OK();
}
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+template <>
+Status IsNaN::Compute(OpKernelContext* context) const {
+ const auto* X = context->Input(0);
+ auto& dims = X->Shape();
+ auto& Y = *context->Output(0, dims);
+
+ auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size()));
+ auto output = EigenMap(Y);
+
+ // S.1111.111
+ std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; });
+ return Status::OK();
+}
+
+template <>
+Status IsNaN::Compute(OpKernelContext* context) const {
+ const auto* X = context->Input(0);
+ auto X_data = X->Data();
+ auto& dims = X->Shape();
+ auto shape_size = dims.Size();
+ auto& Y = *context->Output(0, dims);
+
+ // 1.0000.000
+ EigenMap(Y) =
+ ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80;
+
+ return Status::OK();
+}
+
+template <>
+Status IsNaN::Compute(OpKernelContext* context) const {
+ const auto* X = context->Input(0);
+ auto& dims = X->Shape();
+ auto& Y = *context->Output(0, dims);
+
+ auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size()));
+ auto output = EigenMap(Y);
+
+ // S.11111.{01, 10, 11}
+ std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); });
+ return Status::OK();
+}
+
+template <>
+Status IsNaN::Compute(OpKernelContext* context) const {
+ const auto* X = context->Input(0);
+ auto X_data = X->Data();
+ auto& dims = X->Shape();
+ auto shape_size = dims.Size();
+ auto& Y = *context->Output(0, dims);
+
+ // 1.0000.000
+ EigenMap(Y) = ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80;
+
+ return Status::OK();
+}
+#endif
} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
index ddb392eb82..2e583c5d25 100644
--- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
@@ -17,85 +17,137 @@ constexpr double DOUBLE_INF = std::numeric_limits::infinity();
constexpr double DOUBLE_NINF = -std::numeric_limits::infinity();
constexpr double DOUBLE_NAN = std::numeric_limits::quiet_NaN();
-TEST(IsInfTest, test_isinf_float) {
- // Defaults for detect_negative = 1
- // detect_positive = 1
- OpTester test("IsInf", 10);
-
- std::vector input_dim{6};
- std::vector input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF};
- test.AddInput("X", input_dim, input);
-
- std::vector output_dim(input_dim);
- test.AddOutput("Y", output_dim, {false, false, true, false, true, true});
+template
+void run_is_inf_test(int opset, int64_t detect_positive, int64_t detect_negative, const std::initializer_list& input, const std::initializer_list& output) {
+ OpTester test("IsInf", opset);
+ test.AddAttribute("detect_positive", detect_positive);
+ test.AddAttribute("detect_negative", detect_negative);
+ test.AddInput("X", {onnxruntime::narrow(input.size())}, input);
+ test.AddOutput("Y", {onnxruntime::narrow(output.size())}, output);
test.Run();
}
-TEST(IsInfTest, test_isinf_double) {
- // Defaults for detect_negative = 1
- // detect_positive = 1
- OpTester test("IsInf", 10);
-
- std::vector input_dim{6};
- std::vector input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF};
- test.AddInput("X", input_dim, input);
-
- std::vector output_dim(input_dim);
- test.AddOutput("Y", output_dim, {false, false, true, false, true, true});
- test.Run();
+TEST(IsInfTest, test_isinf_float10) {
+ std::initializer_list input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF};
+ std::initializer_list output = {false, false, true, false, true, true};
+ run_is_inf_test(10, 1, 1, input, output);
}
-TEST(IsInfTest, test_isinf_positive_float) {
- OpTester test("IsInf", 10);
- test.AddAttribute("detect_negative", 0);
-
- std::vector input_dim{6};
- std::vector input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
- test.AddInput("X", input_dim, input);
-
- std::vector output_dim(input_dim);
- test.AddOutput("Y", output_dim, {false, false, true, false, false, true});
- test.Run();
+TEST(IsInfTest, test_isinf_float20) {
+ std::initializer_list input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF};
+ std::initializer_list output = {false, false, true, false, true, true};
+ run_is_inf_test(20, 1, 1, input, output);
}
-TEST(IsInfTest, test_isinf_positive_double) {
- OpTester test("IsInf", 10);
- test.AddAttribute("detect_negative", 0);
-
- std::vector input_dim{6};
- std::vector input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
- test.AddInput("X", input_dim, input);
-
- std::vector output_dim(input_dim);
- test.AddOutput("Y", output_dim, {false, false, true, false, false, true});
- test.Run();
+TEST(IsInfTest, test_isinf_double10) {
+ std::initializer_list input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF};
+ std::initializer_list output = {false, false, true, false, true, true};
+ run_is_inf_test(10, 1, 1, input, output);
}
-TEST(IsInfTest, test_isinf_negative_float) {
- OpTester test("IsInf", 10);
- test.AddAttribute("detect_positive", 0);
-
- std::vector input_dim{6};
- std::vector input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
- test.AddInput("X", input_dim, input);
-
- std::vector output_dim(input_dim);
- test.AddOutput("Y", output_dim, {false, false, false, false, true, false});
- test.Run();
+TEST(IsInfTest, test_isinf_double20) {
+ std::initializer_list input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF};
+ std::initializer_list output = {false, false, true, false, true, true};
+ run_is_inf_test(20, 1, 1, input, output);
}
-TEST(IsInfTest, test_isinf_negative_double) {
- OpTester test("IsInf", 10);
- test.AddAttribute("detect_positive", 0);
-
- std::vector input_dim{6};
- std::vector input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
- test.AddInput("X", input_dim, input);
-
- std::vector output_dim(input_dim);
- test.AddOutput("Y", output_dim, {false, false, false, false, true, false});
- test.Run();
+TEST(IsInfTest, test_isinf_positive_float10) {
+ std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
+ std::initializer_list output = {false, false, true, false, false, true};
+ run_is_inf_test(10, 1, 0, input, output);
}
+TEST(IsInfTest, test_isinf_positive_float20) {
+ std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
+ std::initializer_list output = {false, false, true, false, false, true};
+ run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_positive_double10) {
+ std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
+ std::initializer_list output = {false, false, true, false, false, true};
+ run_is_inf_test(10, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_positive_double20) {
+ std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
+ std::initializer_list output = {false, false, true, false, false, true};
+ run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_float10) {
+ std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
+ std::initializer_list output = {false, false, false, false, true, false};
+ run_is_inf_test(10, 0, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_float20) {
+ std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
+ std::initializer_list output = {false, false, false, false, true, false};
+ run_is_inf_test(20, 0, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_double10) {
+ std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
+ std::initializer_list output = {false, false, false, false, true, false};
+ run_is_inf_test(10, 0, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_double20) {
+ std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
+ std::initializer_list output = {false, false, false, false, true, false};
+ run_is_inf_test(20, 0, 1, input, output);
+}
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+TEST(IsInfTest, test_Float8E4M3FN) {
+ std::initializer_list input = {
+ Float8E4M3FN(-1.0f), Float8E4M3FN(FLOAT_NAN, false), Float8E4M3FN(1.0f), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_INF, false)};
+ std::initializer_list output = {false, false, false, false, false, false};
+ run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_Float8E4M3FNUZ) {
+ std::initializer_list input = {
+ Float8E4M3FNUZ(-1.0f), Float8E4M3FNUZ(FLOAT_NAN, false), Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_INF, false)};
+ std::initializer_list output = {false, false, false, false, false, false};
+ run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_Float8E5M2_detect_both) {
+ std::initializer_list input = {
+ Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
+ std::initializer_list output = {false, true, false, true, false, true};
+ run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_Float8E5M2_detect_positive) {
+ std::initializer_list input = {
+ Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
+ std::initializer_list output = {false, false, false, false, false, true};
+ run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_Float8E5M2_detect_negative) {
+ std::initializer_list input = {
+ Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
+ std::initializer_list output = {false, true, false, true, false, false};
+ run_is_inf_test(20, 0, 1, input, output);
+}
+
+TEST(IsInfTest, test_Float8E5M2_none) {
+ std::initializer_list input = {
+ Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
+ std::initializer_list output = {false, false, false, false, false, false};
+ run_is_inf_test(20, 0, 0, input, output);
+}
+
+TEST(IsInfTest, test_Float8E5M2FNUZ) {
+ std::initializer_list input = {
+ Float8E5M2FNUZ(-1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(FLOAT_NAN, false), Float8E5M2FNUZ(FLOAT_INF, false)};
+ std::initializer_list output = {false, false, false, false, false, false};
+ run_is_inf_test(20, 1, 1, input, output);
+}
+#endif
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
index 0dffc452b5..0f1e5c07cd 100644
--- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
@@ -9,29 +9,84 @@
namespace onnxruntime {
namespace test {
-TEST(IsNaNOpTest, IsNaNFloat) {
- OpTester test("IsNaN", 9, kOnnxDomain);
- std::vector dims{2, 2};
- test.AddInput("X", dims, {1.0f, NAN, 2.0f, NAN});
- test.AddOutput("Y", dims, {false, true, false, true});
+template
+void run_is_nan_test(int opset, const std::vector& dims, const std::initializer_list& input, const std::initializer_list& output) {
+ OpTester test("IsNaN", opset, kOnnxDomain);
+ test.AddInput("X", dims, input);
+ test.AddOutput("Y", dims, output);
test.Run();
}
-TEST(IsNaNOpTest, IsNaNFloat16) {
- OpTester test("IsNaN", 9, kOnnxDomain);
+TEST(IsNaNOpTest, IsNaNFloat9) {
std::vector dims{2, 2};
- test.AddInput("X", dims, std::initializer_list({MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}));
- test.AddOutput("Y", dims, {false, true, false, true});
- test.Run();
+ std::initializer_list input = {1.0f, NAN, 2.0f, NAN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(9, dims, input, output);
}
-TEST(IsNaNOpTest, IsNaNDouble) {
- OpTester test("IsNaN", 9, kOnnxDomain);
+TEST(IsNaNOpTest, IsNaNFloat20) {
std::vector dims{2, 2};
- test.AddInput("X", dims, {1.0, NAN, 2.0, NAN});
- test.AddOutput("Y", dims, {false, true, false, true});
- test.Run();
+ std::initializer_list input = {1.0f, NAN, 2.0f, NAN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
}
+TEST(IsNaNOpTest, IsNaNFloat16_9) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(9, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaNFloat16_20) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaNDouble9) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {1.0, NAN, 2.0, NAN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(9, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaNDouble20) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {1.0, NAN, 2.0, NAN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+TEST(IsNaNOpTest, IsNaNFloat8E4M3FN) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {Float8E4M3FN(1.0f), Float8E4M3FN(-NAN), Float8E4M3FN(2.0f), Float8E4M3FN(NAN)};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaN_Float8E4M3FNUZ) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(-NAN), Float8E4M3FNUZ(2.0f), Float8E4M3FNUZ(-NAN)};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaNFloat8E5M2) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {Float8E5M2(1.0f), Float8E5M2(-NAN), Float8E5M2(2.0f), Float8E5M2(NAN)};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaN_Float8E5M2FNUZ) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(-NAN), Float8E5M2FNUZ(2.0f), Float8E5M2FNUZ(NAN)};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+#endif
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index b3161a42bb..44db7c0078 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -283,12 +283,6 @@
"^test_dft_axis",
"^test_dft",
"^test_dft_inverse",
- "^test_isinf",
- "^test_isinf_float16",
- "^test_isinf_negative",
- "^test_isinf_positive",
- "^test_isnan",
- "^test_isnan_float16",
"^test_reduce_max_bool_inputs",
"^test_reduce_min_bool_inputs",
"^test_reduce_min_empty_set",