mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Support double type for a few ops (#1450)
* Initial commit * More ops * fix missing declarations for ReduceSum and ReduceSumSquare * Add tests for new ops supporting double * isable Add_dobule for OpenVINO EP
This commit is contained in:
parent
9d67292c8c
commit
1fc6f8ee5b
6 changed files with 298 additions and 59 deletions
|
|
@ -31,6 +31,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Ran
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomUniformLike);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Multinomial);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Add);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Add);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Add);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Add);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sub);
|
||||
|
|
@ -61,16 +62,20 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sqrt);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Pow);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, And);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Or);
|
||||
|
|
@ -83,7 +88,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Sin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin);
|
||||
|
|
@ -133,8 +139,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceProd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceProd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ArgMax);
|
||||
|
|
@ -298,6 +306,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomUniformLike)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Multinomial)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sub)>,
|
||||
|
|
@ -328,16 +337,20 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, And)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Or)>,
|
||||
|
|
@ -350,7 +363,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Sin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin)>,
|
||||
|
|
@ -401,8 +415,10 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceProd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ArgMin)>,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,13 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Add<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Add,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Add<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Add,
|
||||
7,
|
||||
|
|
@ -173,24 +180,48 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Reciprocal<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sqrt,
|
||||
6,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sqrt<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sqrt,
|
||||
6,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Sqrt<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Pow,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Pow<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Pow,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Pow<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Exp,
|
||||
6,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Exp<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Exp,
|
||||
6,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Exp<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Log,
|
||||
6,
|
||||
|
|
@ -227,12 +258,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Max_6<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Max,
|
||||
8,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Max_8<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Max,
|
||||
8,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Max_8<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Not,
|
||||
1,
|
||||
|
|
@ -397,43 +436,43 @@ Status Reciprocal<float>::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Sqrt<float>::Compute(OpKernelContext* ctx) const {
|
||||
template <typename T>
|
||||
Status Sqrt<T>::Compute(OpKernelContext* ctx) const {
|
||||
auto& X = *ctx->Input<Tensor>(0);
|
||||
auto& Y = *ctx->Output(0, X.Shape());
|
||||
|
||||
EigenMap<float>(Y) = EigenMap<float>(X).cwiseSqrt();
|
||||
EigenMap<T>(Y) = EigenMap<T>(X).cwiseSqrt();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Pow<float>::Compute(OpKernelContext* context) const {
|
||||
template <typename T>
|
||||
Status Pow<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor& Y = *context->Input<Tensor>(1);
|
||||
std::function<void(EigenVectorMap<float>, ConstEigenVectorMap<float>, float)> input1scalar =
|
||||
[](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = Eigen::pow(input0.array(), input1); };
|
||||
std::function<void(EigenVectorMap<T>, ConstEigenVectorMap<T>, T)> input1scalar =
|
||||
[](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, T input1) { output = Eigen::pow(input0.array(), input1); };
|
||||
if (Y.Shape().Size() == 1) {
|
||||
float value = *Y.Data<float>();
|
||||
T value = *Y.Data<T>();
|
||||
if (value == 2.0) {
|
||||
input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::square(input0.array()); };
|
||||
input1scalar = [](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, T) { output = Eigen::square(input0.array()); };
|
||||
} else if (value == 3.0) {
|
||||
input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::cube(input0.array()); };
|
||||
input1scalar = [](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, T) { output = Eigen::cube(input0.array()); };
|
||||
}
|
||||
}
|
||||
|
||||
return BroadcastTwo<float, float>(
|
||||
return BroadcastTwo<T, T>(
|
||||
*context,
|
||||
[](EigenVectorMap<float> output, float input0, ConstEigenVectorMap<float> input1) { output = Eigen::pow(input0, input1.array()); },
|
||||
[](EigenVectorMap<T> output, T input0, ConstEigenVectorMap<T> input1) { output = Eigen::pow(input0, input1.array()); },
|
||||
input1scalar,
|
||||
[](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, ConstEigenVectorMap<float> input1) { output = Eigen::pow(input0.array(), input1.array()); });
|
||||
[](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, ConstEigenVectorMap<T> input1) { output = Eigen::pow(input0.array(), input1.array()); });
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Exp<float>::Compute(OpKernelContext* ctx) const {
|
||||
template <typename T>
|
||||
Status Exp<T>::Compute(OpKernelContext* ctx) const {
|
||||
auto& X = *ctx->Input<Tensor>(0);
|
||||
auto& Y = *ctx->Output(0, X.Shape());
|
||||
|
||||
EigenMap<float>(Y) = EigenMap<float>(X).array().exp();
|
||||
EigenMap<T>(Y) = EigenMap<T>(X).array().exp();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -527,13 +566,13 @@ Status Max_6<float>::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Max_8<float>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastVariadic<float, float>(
|
||||
template <typename T>
|
||||
Status Max_8<T>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastVariadic<T, T>(
|
||||
Node(), *context,
|
||||
[](EigenVectorMap<float> output, float input0, ConstEigenVectorMap<float> input1) { output = input1.array().max(input0); },
|
||||
[](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = input0.array().max(input1); },
|
||||
[](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, ConstEigenVectorMap<float> input1) { output = input0.array().max(input1.array()); });
|
||||
[](EigenVectorMap<T> output, T input0, ConstEigenVectorMap<T> input1) { output = input1.array().max(input0); },
|
||||
[](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, T input1) { output = input0.array().max(input1); },
|
||||
[](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, ConstEigenVectorMap<T> input1) { output = input0.array().max(input1.array()); });
|
||||
}
|
||||
|
||||
Status Not::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -682,17 +721,25 @@ class Sin final : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override {
|
||||
auto& X = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
MakeEigenArrayMap<float>(Y) = MakeEigenArrayMap<float>(X).sin();
|
||||
MakeEigenArrayMap<T>(Y) = MakeEigenArrayMap<T>(X).sin();
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sin,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sin<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sin,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Sin<double>);
|
||||
|
||||
template <typename T>
|
||||
class Cos final : public OpKernel {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,14 @@ namespace onnxruntime {
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()), \
|
||||
x<int32_t>);
|
||||
|
||||
#define REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(x, sinceVersion) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
x, \
|
||||
sinceVersion, \
|
||||
double, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()), \
|
||||
x<double>);
|
||||
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 1);
|
||||
|
|
@ -31,7 +39,9 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 1);
|
|||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSum, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSumSquare, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 1);
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ TEST(MathOpTest, Add_int64) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Add) {
|
||||
TEST(MathOpTest, Add_float) {
|
||||
OpTester test("Add");
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
test.AddInput<float>("A", dims,
|
||||
|
|
@ -49,6 +49,25 @@ TEST(MathOpTest, Add) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Add_double) {
|
||||
OpTester test("Add");
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
test.AddInput<double>("A", dims,
|
||||
{1.0, 2.0, -1.0,
|
||||
0.0, 1.5, -100.0,
|
||||
-5.4, 9.3, -10'000.0});
|
||||
test.AddInput<double>("B", dims,
|
||||
{-1.0, 4.4, 432.3,
|
||||
0.0, 3.5, 64.0,
|
||||
-5.4, 9.3, 10'000.0});
|
||||
test.AddOutput<double>("C", dims,
|
||||
{0.0, 6.4, 431.3,
|
||||
0.0, 5.0, -36.0,
|
||||
-10.8, 18.6, 0.0});
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // Disabling OpenVINO as this type is not supported
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Add_Broadcast_Axis) {
|
||||
OpTester test("Add");
|
||||
|
||||
|
|
@ -384,7 +403,7 @@ TEST(MathOpTest, Reciprocal) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sqrt) {
|
||||
TEST(MathOpTest, Sqrt_Float) {
|
||||
OpTester test("Sqrt");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("X", dims,
|
||||
|
|
@ -396,7 +415,19 @@ TEST(MathOpTest, Sqrt) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Pow) {
|
||||
TEST(MathOpTest, Sqrt_Double) {
|
||||
OpTester test("Sqrt");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<double>("X", dims,
|
||||
{1.0, 4.0,
|
||||
0.0, 9.0});
|
||||
test.AddOutput<double>("Y", dims,
|
||||
{1.0, 2.0,
|
||||
0.0, 3.0});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Pow_Float) {
|
||||
OpTester test("Pow");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("X", dims,
|
||||
|
|
@ -411,6 +442,21 @@ TEST(MathOpTest, Pow) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Pow_Double) {
|
||||
OpTester test("Pow");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<double>("X", dims,
|
||||
{2.0, 2.0,
|
||||
std::sqrt(2.0), 1.0});
|
||||
test.AddInput<double>("Y", dims,
|
||||
{0.0, 8.0,
|
||||
2.0, 9.0});
|
||||
test.AddOutput<double>("Z", dims,
|
||||
{1.0, 256.0,
|
||||
2.0, 1.0});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Pow_Broadcast_Scalar0) {
|
||||
OpTester test("Pow");
|
||||
|
||||
|
|
@ -431,7 +477,7 @@ TEST(MathOpTest, Pow_Broadcast_Scalar1) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Exp) {
|
||||
TEST(MathOpTest, Exp_float) {
|
||||
OpTester test("Exp");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("X", dims,
|
||||
|
|
@ -444,6 +490,21 @@ TEST(MathOpTest, Exp) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: result differs
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Exp_double) {
|
||||
OpTester test("Exp");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<double>("X", dims,
|
||||
{0.0, 1.0,
|
||||
2.0, 10.0});
|
||||
test.AddOutput<double>("Y", dims,
|
||||
{1.0, std::exp(1.0),
|
||||
std::exp(2.0), std::exp(10.0)});
|
||||
test.SetOutputRelErr("Y", 1e-7f);
|
||||
// TODO: Check if this test's result really differs for tensorRT
|
||||
// For now basing this exclusion based on this test's float counterpart - Exp_float
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Log) {
|
||||
OpTester test("Log");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
|
|
@ -610,7 +671,7 @@ TEST(MathOpTest, Max_6) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Max_8) {
|
||||
TEST(MathOpTest, Max_8_Float) {
|
||||
OpTester test("Max", 8);
|
||||
test.AddInput<float>("data_0", {1, 3},
|
||||
{1.0f, 2.0f, 3.0f});
|
||||
|
|
@ -627,6 +688,23 @@ TEST(MathOpTest, Max_8) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Input batch size is inconsistent
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Max_8_Double) {
|
||||
OpTester test("Max", 8);
|
||||
test.AddInput<double>("data_0", {1, 3},
|
||||
{1.0, 2.0, 3.0});
|
||||
test.AddInput<double>("data_2", {3, 3},
|
||||
{10.0, 20.0, 30.0,
|
||||
40.0, 50.0, 60.0,
|
||||
70.0, 80.0, 90.0});
|
||||
test.AddInput<double>("data_1", {3, 1},
|
||||
{-1.0, -2.0, 300.0});
|
||||
test.AddOutput<double>("max", {3, 3},
|
||||
{10.0, 20.0, 30.0,
|
||||
40.0, 50.0, 60.0,
|
||||
300.0, 300.0, 300.0});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Input batch size is inconsistent
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Max_8_2inputbroadcast) {
|
||||
OpTester test("Max", 8);
|
||||
test.AddInput<float>("data_0", {1, 3},
|
||||
|
|
@ -828,7 +906,7 @@ TEST(MathOpTest, Mean_8) {
|
|||
}
|
||||
|
||||
template <float (&op)(float value)>
|
||||
void TrigTest(OpTester& test, std::initializer_list<float> input) {
|
||||
void TrigFloatTest(OpTester& test, std::initializer_list<float> input) {
|
||||
std::vector<int64_t> dims{static_cast<int64_t>(input.size())};
|
||||
|
||||
std::vector<float> output;
|
||||
|
|
@ -840,59 +918,77 @@ void TrigTest(OpTester& test, std::initializer_list<float> input) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sin) {
|
||||
template <double (&op)(double value)>
|
||||
void TrigDoubleTest(OpTester& test, std::initializer_list<double> input) {
|
||||
std::vector<int64_t> dims{static_cast<int64_t>(input.size())};
|
||||
|
||||
std::vector<double> output;
|
||||
for (auto v : input)
|
||||
output.push_back(op(v));
|
||||
|
||||
test.AddInput<double>("X", dims, input);
|
||||
test.AddOutput<double>("Y", dims, output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, SinFloat) {
|
||||
OpTester test("Sin");
|
||||
TrigTest<std::sin>(test, {1.1f, -1.1f, 2.2f, -2.2f});
|
||||
TrigFloatTest<std::sin>(test, {1.1f, -1.1f, 2.2f, -2.2f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, SinDouble) {
|
||||
OpTester test("Sin");
|
||||
TrigDoubleTest<std::sin>(test, {1.1, -1.1, 2.2, -2.2});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Cos) {
|
||||
OpTester test("Cos");
|
||||
TrigTest<std::cos>(test, {1.1f, -1.1f, 2.2f, -2.2f});
|
||||
TrigFloatTest<std::cos>(test, {1.1f, -1.1f, 2.2f, -2.2f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Tan) {
|
||||
OpTester test("Tan");
|
||||
TrigTest<std::tan>(test, {-100.0f, -50.0f, 0.0f, 50.0f, 100.0f});
|
||||
TrigFloatTest<std::tan>(test, {-100.0f, -50.0f, 0.0f, 50.0f, 100.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Asin) {
|
||||
OpTester test("Asin");
|
||||
TrigTest<std::asin>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
TrigFloatTest<std::asin>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Acos) {
|
||||
OpTester test("Acos");
|
||||
TrigTest<std::acos>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
TrigFloatTest<std::acos>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Atan) {
|
||||
OpTester test("Atan");
|
||||
TrigTest<std::atan>(test, {-10.0f, -5.0f, 0.0f, 5.0f, 10.0f});
|
||||
TrigFloatTest<std::atan>(test, {-10.0f, -5.0f, 0.0f, 5.0f, 10.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sinh) {
|
||||
OpTester test("Sinh", 9);
|
||||
TrigTest<std::sinh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
TrigFloatTest<std::sinh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Cosh) {
|
||||
OpTester test("Cosh", 9);
|
||||
TrigTest<std::cosh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
TrigFloatTest<std::cosh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Asinh) {
|
||||
OpTester test("Asinh", 9);
|
||||
TrigTest<std::asinh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
TrigFloatTest<std::asinh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Acosh) {
|
||||
OpTester test("Acosh", 9);
|
||||
TrigTest<std::acosh>(test, {1.0f, 1.1f, 3.0f, 10.0f, 100.0f});
|
||||
TrigFloatTest<std::acosh>(test, {1.0f, 1.1f, 3.0f, 10.0f, 100.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Atanh) {
|
||||
OpTester test("Atanh", 9);
|
||||
TrigTest<std::atanh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
TrigFloatTest<std::atanh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3) {
|
||||
|
|
@ -999,9 +1095,9 @@ TEST(MathOpTest, Expand_8_3x1x3x1_int64) {
|
|||
test.AddInput<int64_t>("data_0", {1, 3, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
test.AddInput<int64_t>("data_1", {4}, {3, 1, 3, 1});
|
||||
test.AddOutput<int64_t>("result", {3, 3, 3, 3},
|
||||
{1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
|
||||
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
|
||||
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,});
|
||||
{1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
|
||||
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
|
||||
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -676,6 +676,23 @@ TEST(ReductionOpTest, ReduceSum) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceSum_double) {
|
||||
OpTester test("ReduceSum");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{0, 2});
|
||||
test.AddAttribute("keepdims", (int64_t)1);
|
||||
test.AddInput<double>("data", {3, 2, 2},
|
||||
{1.0, 2.0,
|
||||
3.0, 4.0,
|
||||
|
||||
5.0, 6.0,
|
||||
7.0, 8.0,
|
||||
|
||||
9.0, 10.0,
|
||||
11.0, 12.0});
|
||||
test.AddOutput<double>("reduced", {1, 2, 1}, {33.0, 45.0});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceSum_axes01) {
|
||||
OpTester test("ReduceSum");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{2});
|
||||
|
|
@ -798,6 +815,23 @@ TEST(ReductionOpTest, ReduceSumSquare) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceSumSquare_double) {
|
||||
OpTester test("ReduceSumSquare");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{0, 2});
|
||||
test.AddAttribute("keepdims", (int64_t)1);
|
||||
test.AddInput<double>("data", {3, 2, 2},
|
||||
{1.0, 2.0,
|
||||
3.0, 4.0,
|
||||
|
||||
5.0, 6.0,
|
||||
7.0, 8.0,
|
||||
|
||||
9.0, 10.0,
|
||||
11.0, 12.0});
|
||||
test.AddOutput<double>("reduced", {1, 2, 1}, {247.0, 403.});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceSumSquare_int32) {
|
||||
OpTester test("ReduceSumSquare");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{0, 2});
|
||||
|
|
|
|||
|
|
@ -34,6 +34,42 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_tensor, con
|
|||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void Check<double>(const OpTester::Data& expected_data, const Tensor& output_tensor, const std::string& provider_type) {
|
||||
auto& expected_tensor = expected_data.data_.Get<Tensor>();
|
||||
auto* expected = expected_tensor.template Data<double>();
|
||||
auto* output = output_tensor.template Data<double>();
|
||||
auto size = output_tensor.Shape().Size();
|
||||
|
||||
bool has_abs_err = expected_data.absolute_error_.has_value();
|
||||
bool has_rel_err = expected_data.relative_error_.has_value();
|
||||
|
||||
double threshold = 0.001;
|
||||
#ifdef USE_CUDA
|
||||
threshold = 0.005;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if (std::isinf(expected[i])) { // Test infinity for equality
|
||||
EXPECT_EQ(expected[i], output[i]);
|
||||
} else if (std::isnan(expected[i])) {
|
||||
EXPECT_TRUE(std::isnan(output[i])) << "Expected output " << i << " to be NaN";
|
||||
} else {
|
||||
if (!has_abs_err && !has_rel_err) {
|
||||
// the default for existing tests
|
||||
EXPECT_NEAR(expected[i], output[i], threshold) << "provider_type: " << provider_type;
|
||||
} else {
|
||||
if (has_abs_err) {
|
||||
EXPECT_NEAR(expected[i], output[i], expected_data.absolute_error_.value()) << "provider_type: " << provider_type;
|
||||
}
|
||||
if (has_rel_err) {
|
||||
EXPECT_NEAR(expected[i], output[i], expected_data.relative_error_.value() * std::abs(expected[i])) << "provider_type: " << provider_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void Check<float>(const OpTester::Data& expected_data, const Tensor& output_tensor, const std::string& provider_type) {
|
||||
auto& expected_tensor = expected_data.data_.Get<Tensor>();
|
||||
|
|
|
|||
Loading…
Reference in a new issue