Register kernel for Greater int64 (#1546)

Register int64 for Greater and refactor the register code
This commit is contained in:
Yufeng Li 2019-08-02 14:01:43 -07:00 committed by GitHub
parent cb71c69d5e
commit a098be12ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 336 deletions

View file

@ -59,21 +59,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Neg);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Floor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Ceil);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Reciprocal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Log);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Sum);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Min);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not);
@ -86,8 +86,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos);
@ -216,6 +216,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Com
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
@ -223,7 +224,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_string_int64_t, OneHot);
@ -334,21 +335,21 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Floor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Ceil)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Reciprocal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Log)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not)>,
@ -361,8 +362,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos)>,
@ -491,6 +492,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast)>,
@ -498,7 +500,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_string_int64_t, OneHot)>,

View file

@ -10,267 +10,100 @@
namespace onnxruntime {
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Add,
7,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Add<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Add,
7,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Add<double>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Add,
7,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Add<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Add,
7,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
Add<int64_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Sub,
7,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Sub<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Sub,
7,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Sub<double>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Sub,
7,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Sub<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Sub,
7,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
Sub<int64_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Mul,
7,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Mul<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Mul,
7,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Mul<double>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Mul,
7,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Mul<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Mul,
7,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
Mul<int64_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Div,
7,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Div<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Div,
7,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Div<double>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Div,
7,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Div<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Div,
7,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
Div<int64_t>);
#define REG_ABS_KERNEL(TYPE) \
#define REG_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Abs, \
6, \
OP_TYPE, \
VERSION, \
TYPE, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
Abs<TYPE>);
KERNEL_CLASS<TYPE>);
REG_ABS_KERNEL(float)
REG_ABS_KERNEL(double)
REG_ABS_KERNEL(int8_t)
REG_ABS_KERNEL(int16_t)
REG_ABS_KERNEL(int32_t)
REG_ABS_KERNEL(int64_t)
REG_ABS_KERNEL(uint8_t)
REG_ABS_KERNEL(uint16_t)
REG_ABS_KERNEL(uint32_t)
REG_ABS_KERNEL(uint64_t)
#define REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
OP_TYPE, \
VERSION_FROM, VERSION_TO, \
TYPE, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
KERNEL_CLASS<TYPE>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Neg,
6,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Neg<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, float, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, double, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int32_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int64_t, Add);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Neg,
6,
int8_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int8_t>()),
Neg<int8_t>);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, float, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, double, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, int32_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, int64_t, Sub);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Neg,
6,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Neg<int32_t>);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, float, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, double, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, int32_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, int64_t, Mul);
ONNX_CPU_OPERATOR_KERNEL(
Floor,
6,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Floor<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, float, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, double, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, int32_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, int64_t, Div);
ONNX_CPU_OPERATOR_KERNEL(
Ceil,
6,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Ceil<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, float, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, double, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int8_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int16_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int32_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int64_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint8_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint16_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint32_t, Abs);
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint64_t, Abs);
ONNX_CPU_OPERATOR_KERNEL(
Reciprocal,
6,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Reciprocal<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, float, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, int8_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, int32_t, Neg);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Sqrt,
6,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Sqrt<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Floor, 6, float, Floor);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Sqrt,
6,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Sqrt<double>);
REG_ELEMENTWISE_TYPED_KERNEL(Ceil, 6, float, Ceil);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Pow,
7,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Pow<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Reciprocal, 6, float, Reciprocal);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Pow,
7,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Pow<double>);
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 6, float, Sqrt);
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 6, double, Sqrt);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Exp,
6,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Exp<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Pow, 7, float, Pow);
REG_ELEMENTWISE_TYPED_KERNEL(Pow, 7, double, Pow);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Exp,
6,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Exp<double>);
REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, float, Exp);
REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, double, Exp);
ONNX_CPU_OPERATOR_KERNEL(
Log,
6,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Log<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Log, 6, float, Log);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Sum,
6, 7,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Sum_6<float>);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 6, 7, float, Sum_6);
REG_ELEMENTWISE_TYPED_KERNEL(Sum, 8, float, Sum_8);
ONNX_CPU_OPERATOR_KERNEL(
Sum,
8,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Sum_8<float>);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6);
REG_ELEMENTWISE_TYPED_KERNEL(Min, 8, float, Min_8);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Min,
6, 7,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Min_6<float>);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6);
REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, float, Max_8);
REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, double, Max_8);
ONNX_CPU_OPERATOR_KERNEL(
Min,
8,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Min_8<float>);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less);
REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int32_t, Less);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Max,
6, 7,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Max_6<float>);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater)
REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int32_t, Greater);
REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int64_t, Greater);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Max,
8,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Max_8<float>);
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, bool, Equal);
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int32_t, Equal);
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int64_t, Equal);
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 11, float, Equal);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Max,
8,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
Max_8<double>);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6);
REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8);
REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf);
ONNX_CPU_OPERATOR_KERNEL(
Not,
@ -296,80 +129,6 @@ ONNX_CPU_OPERATOR_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>()),
Xor);
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Less,
7, 9,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Less<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Less,
9,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Less<int32_t>);
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
Greater,
7, 9,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Greater<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Greater,
9,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Greater<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Equal,
7,
bool,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>()),
Equal<bool>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Equal,
7,
int32_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
Equal<int32_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Equal,
7,
int64_t,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
Equal<int64_t>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Equal,
11,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Equal<float>);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Mean,
6, 7,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Mean_6<float>);
ONNX_CPU_OPERATOR_KERNEL(
Mean,
8,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Mean_8<float>);
ONNX_CPU_OPERATOR_KERNEL(
Erf,
9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Erf<float>);
template <typename T>
Status Add<T>::Compute(OpKernelContext* context) const {
return BroadcastTwo<T, T>(

View file

@ -806,7 +806,7 @@ TEST(MathOpTest, Less_Scalar1) {
test.Run();
}
TEST(MathOpTest, Greater) {
TEST(MathOpTest, Greater_7) {
OpTester test("Greater");
std::vector<int64_t> dims{4};
test.AddInput<float>("A", dims, {1.0f, 0.0f, -1.0f, -1.0f});
@ -815,6 +815,36 @@ TEST(MathOpTest, Greater) {
test.Run();
}
TEST( MathOpTest, Greater_9_float )
{
OpTester test( "Greater", 9 );
std::vector<int64_t> dims { 4 };
test.AddInput<float>( "A", dims, { 1.0f, 0.0f, -1.0f, -1.0f } );
test.AddInput<float>( "B", dims, { 1.0f, 1.0f, 2.0f, -1.0f } );
test.AddOutput<bool>( "C", dims, { false, false, false, false } );
test.Run();
}
TEST( MathOpTest, Greater_9_int32 )
{
OpTester test( "Greater", 9 );
std::vector<int64_t> dims { 4 };
test.AddInput<int32_t>( "A", dims, { 10, 11, 12, 13 } );
test.AddInput<int32_t>( "B", dims, { 15, 7, 12, 9 } );
test.AddOutput<bool>( "C", dims, { false, true, false, true } );
test.Run();
}
TEST( MathOpTest, Greater_9_int64 )
{
OpTester test( "Greater", 9 );
std::vector<int64_t> dims { 4 };
test.AddInput<int64_t>( "A", dims, { 10, 11, 12, 13 } );
test.AddInput<int64_t>( "B", dims, { 15, 7, 12, 9 } );
test.AddOutput<bool>( "C", dims, { false, true, false, true } );
test.Run();
}
TEST(MathOpTest, Equal_bool) {
OpTester test("Equal");
std::vector<int64_t> dims{4};