mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
implement isinf20 and isnan20 (#17874)
This commit is contained in:
parent
abb329179a
commit
efa0cc2562
8 changed files with 400 additions and 112 deletions
|
|
@ -156,8 +156,10 @@ Do not modify directly.*
|
|||
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|
||||
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
|
||||
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|
||||
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|
||||
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|
||||
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|
||||
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|
||||
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|
||||
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|
||||
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
|
||||
|||[1, 12]|**T** = tensor(float)|
|
||||
|
|
|
|||
|
|
@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
|
|||
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
|
||||
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
|
||||
if (saturate) {
|
||||
// the highest available value
|
||||
val |= 0x7F;
|
||||
} else {
|
||||
// infinity
|
||||
// NaN
|
||||
val = 0x80;
|
||||
}
|
||||
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
|
||||
|
|
@ -362,8 +363,10 @@ struct Float8E5M2 {
|
|||
val = (b & 0x80000000) >> 24; // sign
|
||||
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
|
||||
if (saturate) {
|
||||
// the highest available value
|
||||
val |= 0x7B;
|
||||
} else {
|
||||
// the infinity
|
||||
val |= 0x7C;
|
||||
}
|
||||
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
|
||||
|
|
|
|||
|
|
@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence);
|
||||
|
|
@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
|
||||
|
|
@ -960,6 +960,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
|
|||
|
||||
// Opset 20
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
|
||||
#endif
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
|
||||
|
||||
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
|
||||
|
||||
|
|
@ -1492,7 +1502,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
NonMaxSuppression)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float,
|
||||
RoiAlign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double,
|
||||
|
|
@ -1981,12 +1991,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
|
||||
IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double,
|
||||
IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
|
||||
IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float,
|
||||
IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double,
|
||||
IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16,
|
||||
IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool,
|
||||
NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
|
||||
|
|
@ -2389,6 +2399,16 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
// Opset 20
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN)>,
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN)>,
|
||||
#endif
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -14,15 +14,38 @@ namespace onnxruntime {
|
|||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
|
||||
float, double);
|
||||
using IsInfTypesOpset10 = TypeList<float, double>;
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
|
||||
kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0,
|
||||
IsInfTypesOpset10);
|
||||
|
||||
using IsInfTypesOpset20 =
|
||||
TypeList<
|
||||
float,
|
||||
double
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
,
|
||||
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
|
||||
#endif
|
||||
>;
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
|
||||
kCpuExecutionProvider,
|
||||
kOnnxDomain,
|
||||
IsInf,
|
||||
20,
|
||||
Input,
|
||||
0,
|
||||
IsInfTypesOpset20);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
class IsInf final : public OpKernel {
|
||||
public:
|
||||
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
IsInf, Input, 0);
|
||||
using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
|
||||
IsInf, 10, Input, 0);
|
||||
using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
|
||||
IsInf, 20, Input, 0);
|
||||
|
||||
explicit IsInf(const OpKernelInfo& info);
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
|
@ -30,14 +53,25 @@ class IsInf final : public OpKernel {
|
|||
private:
|
||||
int64_t detect_positive_{1};
|
||||
int64_t detect_negative_{1};
|
||||
int opset_;
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
IsInf,
|
||||
10,
|
||||
19,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1",
|
||||
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes10>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
|
||||
IsInf);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
IsInf,
|
||||
10,
|
||||
20,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1",
|
||||
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes>())
|
||||
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes20>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
|
||||
IsInf);
|
||||
|
||||
|
|
@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
|
|||
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
|
||||
status = info.GetAttr("detect_negative", &detect_negative_);
|
||||
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
|
||||
opset_ = info.node().SinceVersion();
|
||||
}
|
||||
|
||||
namespace isinf_internal {
|
||||
|
|
@ -78,6 +113,49 @@ struct ComputeDispatchTarget {
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
template <>
|
||||
struct ComputeDispatchTarget<Float8E4M3FN> {
|
||||
void operator()(const Tensor&, Tensor& Y, bool, bool) const {
|
||||
EigenMap<bool>(Y).array() = false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ComputeDispatchTarget<Float8E4M3FNUZ> {
|
||||
void operator()(const Tensor&, Tensor& Y, bool, bool) const {
|
||||
EigenMap<bool>(Y).array() = false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ComputeDispatchTarget<Float8E5M2> {
|
||||
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
|
||||
auto& dims = X.Shape();
|
||||
auto input = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X.Data<Float8E5M2>())), onnxruntime::narrow<size_t>(dims.Size()));
|
||||
auto output = EigenMap<bool>(Y);
|
||||
|
||||
// S.11111.00
|
||||
if (detect_positive && detect_negative) {
|
||||
output.array() = input.array() == 0b01111100 || input.array() == 0b11111100;
|
||||
} else if (detect_positive) {
|
||||
output.array() = input.array() == 0b01111100;
|
||||
} else if (detect_negative) {
|
||||
output.array() = input.array() == 0b11111100;
|
||||
} else {
|
||||
output.array() = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ComputeDispatchTarget<Float8E5M2FNUZ> {
|
||||
void operator()(const Tensor&, Tensor& Y, bool, bool) const {
|
||||
EigenMap<bool>(Y).array() = false;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
} // namespace isinf_internal
|
||||
|
||||
Status IsInf::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const {
|
|||
|
||||
using namespace isinf_internal;
|
||||
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes> dispatcher{X.GetElementType()};
|
||||
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
|
||||
if (opset_ < 20) {
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes10> dispatcher{X.GetElementType()};
|
||||
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
|
||||
} else {
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes20> dispatcher{X.GetElementType()};
|
||||
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,10 +20,20 @@ namespace onnxruntime {
|
|||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()), \
|
||||
IsNaN<data_type>);
|
||||
|
||||
#define ADD_TYPED_ISNAN_OP_13(data_type) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
IsNaN, \
|
||||
13, 19, \
|
||||
data_type, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()), \
|
||||
IsNaN<data_type>);
|
||||
|
||||
#define ADD_TYPED_ISNAN_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
IsNaN, \
|
||||
13, \
|
||||
20, \
|
||||
data_type, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
|
||||
|
|
@ -33,10 +43,20 @@ namespace onnxruntime {
|
|||
ADD_TYPED_ISNAN_OP_9(float);
|
||||
ADD_TYPED_ISNAN_OP_9(double);
|
||||
ADD_TYPED_ISNAN_OP_9(MLFloat16);
|
||||
ADD_TYPED_ISNAN_OP_13(float);
|
||||
ADD_TYPED_ISNAN_OP_13(double);
|
||||
ADD_TYPED_ISNAN_OP_13(MLFloat16);
|
||||
ADD_TYPED_ISNAN_OP(float);
|
||||
ADD_TYPED_ISNAN_OP(double);
|
||||
ADD_TYPED_ISNAN_OP(MLFloat16);
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
ADD_TYPED_ISNAN_OP(Float8E4M3FN);
|
||||
ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ);
|
||||
ADD_TYPED_ISNAN_OP(Float8E5M2);
|
||||
ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
Status IsNaN<T>::Compute(OpKernelContext* context) const {
|
||||
const auto* X_ptr = context->Input<Tensor>(0);
|
||||
|
|
@ -70,4 +90,63 @@ Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
template <>
|
||||
Status IsNaN<Float8E4M3FN>::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
auto& dims = X->Shape();
|
||||
auto& Y = *context->Output(0, dims);
|
||||
|
||||
auto input = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X->Data<Float8E4M3FN>())), onnxruntime::narrow<size_t>(dims.Size()));
|
||||
auto output = EigenMap<bool>(Y);
|
||||
|
||||
// S.1111.111
|
||||
std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status IsNaN<Float8E4M3FNUZ>::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
auto X_data = X->Data<Float8E4M3FNUZ>();
|
||||
auto& dims = X->Shape();
|
||||
auto shape_size = dims.Size();
|
||||
auto& Y = *context->Output(0, dims);
|
||||
|
||||
// 1.0000.000
|
||||
EigenMap<bool>(Y) =
|
||||
ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X_data)), onnxruntime::narrow<size_t>(shape_size)).array() == 0x80;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status IsNaN<Float8E5M2>::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
auto& dims = X->Shape();
|
||||
auto& Y = *context->Output(0, dims);
|
||||
|
||||
auto input = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X->Data<Float8E5M2>())), onnxruntime::narrow<size_t>(dims.Size()));
|
||||
auto output = EigenMap<bool>(Y);
|
||||
|
||||
// S.11111.{01, 10, 11}
|
||||
std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status IsNaN<Float8E5M2FNUZ>::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
auto X_data = X->Data<Float8E5M2FNUZ>();
|
||||
auto& dims = X->Shape();
|
||||
auto shape_size = dims.Size();
|
||||
auto& Y = *context->Output(0, dims);
|
||||
|
||||
// 1.0000.000
|
||||
EigenMap<bool>(Y) = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X_data)), onnxruntime::narrow<size_t>(shape_size)).array() == 0x80;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -17,85 +17,137 @@ constexpr double DOUBLE_INF = std::numeric_limits<double>::infinity();
|
|||
constexpr double DOUBLE_NINF = -std::numeric_limits<double>::infinity();
|
||||
constexpr double DOUBLE_NAN = std::numeric_limits<double>::quiet_NaN();
|
||||
|
||||
TEST(IsInfTest, test_isinf_float) {
|
||||
// Defaults for detect_negative = 1
|
||||
// detect_positive = 1
|
||||
OpTester test("IsInf", 10);
|
||||
|
||||
std::vector<int64_t> input_dim{6};
|
||||
std::vector<float> input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF};
|
||||
test.AddInput<float>("X", input_dim, input);
|
||||
|
||||
std::vector<int64_t> output_dim(input_dim);
|
||||
test.AddOutput<bool>("Y", output_dim, {false, false, true, false, true, true});
|
||||
template <typename T>
|
||||
void run_is_inf_test(int opset, int64_t detect_positive, int64_t detect_negative, const std::initializer_list<T>& input, const std::initializer_list<bool>& output) {
|
||||
OpTester test("IsInf", opset);
|
||||
test.AddAttribute<int64_t>("detect_positive", detect_positive);
|
||||
test.AddAttribute<int64_t>("detect_negative", detect_negative);
|
||||
test.AddInput<T>("X", {onnxruntime::narrow<int64_t>(input.size())}, input);
|
||||
test.AddOutput<bool>("Y", {onnxruntime::narrow<int64_t>(output.size())}, output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_double) {
|
||||
// Defaults for detect_negative = 1
|
||||
// detect_positive = 1
|
||||
OpTester test("IsInf", 10);
|
||||
|
||||
std::vector<int64_t> input_dim{6};
|
||||
std::vector<double> input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF};
|
||||
test.AddInput<double>("X", input_dim, input);
|
||||
|
||||
std::vector<int64_t> output_dim(input_dim);
|
||||
test.AddOutput<bool>("Y", output_dim, {false, false, true, false, true, true});
|
||||
test.Run();
|
||||
TEST(IsInfTest, test_isinf_float10) {
|
||||
std::initializer_list<float> input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, true, true};
|
||||
run_is_inf_test(10, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_positive_float) {
|
||||
OpTester test("IsInf", 10);
|
||||
test.AddAttribute<int64_t>("detect_negative", 0);
|
||||
|
||||
std::vector<int64_t> input_dim{6};
|
||||
std::vector<float> input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
|
||||
test.AddInput<float>("X", input_dim, input);
|
||||
|
||||
std::vector<int64_t> output_dim(input_dim);
|
||||
test.AddOutput<bool>("Y", output_dim, {false, false, true, false, false, true});
|
||||
test.Run();
|
||||
TEST(IsInfTest, test_isinf_float20) {
|
||||
std::initializer_list<float> input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, true, true};
|
||||
run_is_inf_test(20, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_positive_double) {
|
||||
OpTester test("IsInf", 10);
|
||||
test.AddAttribute<int64_t>("detect_negative", 0);
|
||||
|
||||
std::vector<int64_t> input_dim{6};
|
||||
std::vector<double> input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
|
||||
test.AddInput<double>("X", input_dim, input);
|
||||
|
||||
std::vector<int64_t> output_dim(input_dim);
|
||||
test.AddOutput<bool>("Y", output_dim, {false, false, true, false, false, true});
|
||||
test.Run();
|
||||
TEST(IsInfTest, test_isinf_double10) {
|
||||
std::initializer_list<double> input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, true, true};
|
||||
run_is_inf_test(10, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_negative_float) {
|
||||
OpTester test("IsInf", 10);
|
||||
test.AddAttribute<int64_t>("detect_positive", 0);
|
||||
|
||||
std::vector<int64_t> input_dim{6};
|
||||
std::vector<float> input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
|
||||
test.AddInput<float>("X", input_dim, input);
|
||||
|
||||
std::vector<int64_t> output_dim(input_dim);
|
||||
test.AddOutput<bool>("Y", output_dim, {false, false, false, false, true, false});
|
||||
test.Run();
|
||||
TEST(IsInfTest, test_isinf_double20) {
|
||||
std::initializer_list<double> input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, true, true};
|
||||
run_is_inf_test(20, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_negative_double) {
|
||||
OpTester test("IsInf", 10);
|
||||
test.AddAttribute<int64_t>("detect_positive", 0);
|
||||
|
||||
std::vector<int64_t> input_dim{6};
|
||||
std::vector<double> input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
|
||||
test.AddInput<double>("X", input_dim, input);
|
||||
|
||||
std::vector<int64_t> output_dim(input_dim);
|
||||
test.AddOutput<bool>("Y", output_dim, {false, false, false, false, true, false});
|
||||
test.Run();
|
||||
TEST(IsInfTest, test_isinf_positive_float10) {
|
||||
std::initializer_list<double> input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, false, true};
|
||||
run_is_inf_test(10, 1, 0, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_positive_float20) {
|
||||
std::initializer_list<double> input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, false, true};
|
||||
run_is_inf_test(20, 1, 0, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_positive_double10) {
|
||||
std::initializer_list<double> input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, false, true};
|
||||
run_is_inf_test(10, 1, 0, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_positive_double20) {
|
||||
std::initializer_list<double> input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
|
||||
std::initializer_list<bool> output = {false, false, true, false, false, true};
|
||||
run_is_inf_test(20, 1, 0, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_negative_float10) {
|
||||
std::initializer_list<float> input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
|
||||
std::initializer_list<bool> output = {false, false, false, false, true, false};
|
||||
run_is_inf_test(10, 0, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_negative_float20) {
|
||||
std::initializer_list<float> input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF};
|
||||
std::initializer_list<bool> output = {false, false, false, false, true, false};
|
||||
run_is_inf_test(20, 0, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_negative_double10) {
|
||||
std::initializer_list<double> input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
|
||||
std::initializer_list<bool> output = {false, false, false, false, true, false};
|
||||
run_is_inf_test(10, 0, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_isinf_negative_double20) {
|
||||
std::initializer_list<double> input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF};
|
||||
std::initializer_list<bool> output = {false, false, false, false, true, false};
|
||||
run_is_inf_test(20, 0, 1, input, output);
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
TEST(IsInfTest, test_Float8E4M3FN) {
|
||||
std::initializer_list<Float8E4M3FN> input = {
|
||||
Float8E4M3FN(-1.0f), Float8E4M3FN(FLOAT_NAN, false), Float8E4M3FN(1.0f), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, false, false, false, false, false};
|
||||
run_is_inf_test(20, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_Float8E4M3FNUZ) {
|
||||
std::initializer_list<Float8E4M3FNUZ> input = {
|
||||
Float8E4M3FNUZ(-1.0f), Float8E4M3FNUZ(FLOAT_NAN, false), Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, false, false, false, false, false};
|
||||
run_is_inf_test(20, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_Float8E5M2_detect_both) {
|
||||
std::initializer_list<Float8E5M2> input = {
|
||||
Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, true, false, true, false, true};
|
||||
run_is_inf_test(20, 1, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_Float8E5M2_detect_positive) {
|
||||
std::initializer_list<Float8E5M2> input = {
|
||||
Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, false, false, false, false, true};
|
||||
run_is_inf_test(20, 1, 0, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_Float8E5M2_detect_negative) {
|
||||
std::initializer_list<Float8E5M2> input = {
|
||||
Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, true, false, true, false, false};
|
||||
run_is_inf_test(20, 0, 1, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_Float8E5M2_none) {
|
||||
std::initializer_list<Float8E5M2> input = {
|
||||
Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, false, false, false, false, false};
|
||||
run_is_inf_test(20, 0, 0, input, output);
|
||||
}
|
||||
|
||||
TEST(IsInfTest, test_Float8E5M2FNUZ) {
|
||||
std::initializer_list<Float8E5M2FNUZ> input = {
|
||||
Float8E5M2FNUZ(-1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(FLOAT_NAN, false), Float8E5M2FNUZ(FLOAT_INF, false)};
|
||||
std::initializer_list<bool> output = {false, false, false, false, false, false};
|
||||
run_is_inf_test(20, 1, 1, input, output);
|
||||
}
|
||||
#endif
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,29 +9,84 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNFloat) {
|
||||
OpTester test("IsNaN", 9, kOnnxDomain);
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("X", dims, {1.0f, NAN, 2.0f, NAN});
|
||||
test.AddOutput<bool>("Y", dims, {false, true, false, true});
|
||||
template <typename T>
|
||||
void run_is_nan_test(int opset, const std::vector<int64_t>& dims, const std::initializer_list<T>& input, const std::initializer_list<bool>& output) {
|
||||
OpTester test("IsNaN", opset, kOnnxDomain);
|
||||
test.AddInput<T>("X", dims, input);
|
||||
test.AddOutput<bool>("Y", dims, output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNFloat16) {
|
||||
OpTester test("IsNaN", 9, kOnnxDomain);
|
||||
TEST(IsNaNOpTest, IsNaNFloat9) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<MLFloat16>("X", dims, std::initializer_list<MLFloat16>({MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}));
|
||||
test.AddOutput<bool>("Y", dims, {false, true, false, true});
|
||||
test.Run();
|
||||
std::initializer_list<float> input = {1.0f, NAN, 2.0f, NAN};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(9, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNDouble) {
|
||||
OpTester test("IsNaN", 9, kOnnxDomain);
|
||||
TEST(IsNaNOpTest, IsNaNFloat20) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<double>("X", dims, {1.0, NAN, 2.0, NAN});
|
||||
test.AddOutput<bool>("Y", dims, {false, true, false, true});
|
||||
test.Run();
|
||||
std::initializer_list<float> input = {1.0f, NAN, 2.0f, NAN};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNFloat16_9) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<MLFloat16> input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(9, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNFloat16_20) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<MLFloat16> input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNDouble9) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<double> input = {1.0, NAN, 2.0, NAN};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(9, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNDouble20) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<double> input = {1.0, NAN, 2.0, NAN};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
TEST(IsNaNOpTest, IsNaNFloat8E4M3FN) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<Float8E4M3FN> input = {Float8E4M3FN(1.0f), Float8E4M3FN(-NAN), Float8E4M3FN(2.0f), Float8E4M3FN(NAN)};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaN_Float8E4M3FNUZ) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<Float8E4M3FNUZ> input = {Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(-NAN), Float8E4M3FNUZ(2.0f), Float8E4M3FNUZ(-NAN)};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaNFloat8E5M2) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<Float8E5M2> input = {Float8E5M2(1.0f), Float8E5M2(-NAN), Float8E5M2(2.0f), Float8E5M2(NAN)};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
|
||||
TEST(IsNaNOpTest, IsNaN_Float8E5M2FNUZ) {
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::initializer_list<Float8E5M2FNUZ> input = {Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(-NAN), Float8E5M2FNUZ(2.0f), Float8E5M2FNUZ(NAN)};
|
||||
std::initializer_list<bool> output = {false, true, false, true};
|
||||
run_is_nan_test(20, dims, input, output);
|
||||
}
|
||||
#endif
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -283,12 +283,6 @@
|
|||
"^test_dft_axis",
|
||||
"^test_dft",
|
||||
"^test_dft_inverse",
|
||||
"^test_isinf",
|
||||
"^test_isinf_float16",
|
||||
"^test_isinf_negative",
|
||||
"^test_isinf_positive",
|
||||
"^test_isnan",
|
||||
"^test_isnan_float16",
|
||||
"^test_reduce_max_bool_inputs",
|
||||
"^test_reduce_min_bool_inputs",
|
||||
"^test_reduce_min_empty_set",
|
||||
|
|
|
|||
Loading…
Reference in a new issue