mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Register kernel for Greater int64 (#1546)
Register int64 for Greater and refactor the register code
This commit is contained in:
parent
cb71c69d5e
commit
a098be12ba
3 changed files with 127 additions and 336 deletions
|
|
@ -59,21 +59,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Neg);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Neg);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Neg);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Floor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Ceil);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Reciprocal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Log);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not);
|
||||
|
|
@ -86,8 +86,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos);
|
||||
|
|
@ -216,6 +216,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Com
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
|
||||
|
|
@ -223,7 +224,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_string_int64_t, OneHot);
|
||||
|
|
@ -334,21 +335,21 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Neg)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Neg)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Neg)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Floor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Ceil)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not)>,
|
||||
|
|
@ -361,8 +362,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos)>,
|
||||
|
|
@ -491,6 +492,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast)>,
|
||||
|
|
@ -498,7 +500,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_string_int64_t, OneHot)>,
|
||||
|
|
|
|||
|
|
@ -10,267 +10,100 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Add,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Add<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Add,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Add<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Add,
|
||||
7,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Add<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Add,
|
||||
7,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Add<int64_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sub,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sub<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sub,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Sub<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sub,
|
||||
7,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Sub<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sub,
|
||||
7,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Sub<int64_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Mul,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Mul<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Mul,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Mul<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Mul,
|
||||
7,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Mul<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Mul,
|
||||
7,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Mul<int64_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Div,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Div<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Div,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Div<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Div,
|
||||
7,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Div<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Div,
|
||||
7,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Div<int64_t>);
|
||||
|
||||
#define REG_ABS_KERNEL(TYPE) \
|
||||
#define REG_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Abs, \
|
||||
6, \
|
||||
OP_TYPE, \
|
||||
VERSION, \
|
||||
TYPE, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
Abs<TYPE>);
|
||||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
REG_ABS_KERNEL(float)
|
||||
REG_ABS_KERNEL(double)
|
||||
REG_ABS_KERNEL(int8_t)
|
||||
REG_ABS_KERNEL(int16_t)
|
||||
REG_ABS_KERNEL(int32_t)
|
||||
REG_ABS_KERNEL(int64_t)
|
||||
REG_ABS_KERNEL(uint8_t)
|
||||
REG_ABS_KERNEL(uint16_t)
|
||||
REG_ABS_KERNEL(uint32_t)
|
||||
REG_ABS_KERNEL(uint64_t)
|
||||
#define REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION_FROM, VERSION_TO, \
|
||||
TYPE, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Neg,
|
||||
6,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Neg<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, float, Add);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, double, Add);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int32_t, Add);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int64_t, Add);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Neg,
|
||||
6,
|
||||
int8_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int8_t>()),
|
||||
Neg<int8_t>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, float, Sub);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, double, Sub);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, int32_t, Sub);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, int64_t, Sub);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Neg,
|
||||
6,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Neg<int32_t>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, float, Mul);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, double, Mul);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, int32_t, Mul);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, int64_t, Mul);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Floor,
|
||||
6,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Floor<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, float, Div);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, double, Div);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, int32_t, Div);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, int64_t, Div);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Ceil,
|
||||
6,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Ceil<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, float, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, double, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int8_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int16_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int32_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int64_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint8_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint16_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint32_t, Abs);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint64_t, Abs);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Reciprocal,
|
||||
6,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Reciprocal<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, float, Neg);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, int8_t, Neg);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, int32_t, Neg);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sqrt,
|
||||
6,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sqrt<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Floor, 6, float, Floor);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Sqrt,
|
||||
6,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Sqrt<double>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Ceil, 6, float, Ceil);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Pow,
|
||||
7,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Pow<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Reciprocal, 6, float, Reciprocal);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Pow,
|
||||
7,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Pow<double>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 6, float, Sqrt);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 6, double, Sqrt);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Exp,
|
||||
6,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Exp<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Pow, 7, float, Pow);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Pow, 7, double, Pow);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Exp,
|
||||
6,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Exp<double>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, float, Exp);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, double, Exp);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Log,
|
||||
6,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Log<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Log, 6, float, Log);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Sum,
|
||||
6, 7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sum_6<float>);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 6, 7, float, Sum_6);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sum, 8, float, Sum_8);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Sum,
|
||||
8,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sum_8<float>);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Min, 8, float, Min_8);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Min,
|
||||
6, 7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Min_6<float>);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, float, Max_8);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, double, Max_8);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Min,
|
||||
8,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Min_8<float>);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int32_t, Less);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Max,
|
||||
6, 7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Max_6<float>);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater)
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int32_t, Greater);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int64_t, Greater);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Max,
|
||||
8,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Max_8<float>);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, bool, Equal);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int32_t, Equal);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int64_t, Equal);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Equal, 11, float, Equal);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Max,
|
||||
8,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
Max_8<double>);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Not,
|
||||
|
|
@ -296,80 +129,6 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>()),
|
||||
Xor);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Less,
|
||||
7, 9,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Less<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Less,
|
||||
9,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Less<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Greater,
|
||||
7, 9,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Greater<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Greater,
|
||||
9,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Greater<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Equal,
|
||||
7,
|
||||
bool,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>()),
|
||||
Equal<bool>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Equal,
|
||||
7,
|
||||
int32_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Equal<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Equal,
|
||||
7,
|
||||
int64_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Equal<int64_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Equal,
|
||||
11,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Equal<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Mean,
|
||||
6, 7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Mean_6<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Mean,
|
||||
8,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Mean_8<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Erf,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Erf<float>);
|
||||
|
||||
template <typename T>
|
||||
Status Add<T>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastTwo<T, T>(
|
||||
|
|
|
|||
|
|
@ -806,7 +806,7 @@ TEST(MathOpTest, Less_Scalar1) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Greater) {
|
||||
TEST(MathOpTest, Greater_7) {
|
||||
OpTester test("Greater");
|
||||
std::vector<int64_t> dims{4};
|
||||
test.AddInput<float>("A", dims, {1.0f, 0.0f, -1.0f, -1.0f});
|
||||
|
|
@ -815,6 +815,36 @@ TEST(MathOpTest, Greater) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST( MathOpTest, Greater_9_float )
|
||||
{
|
||||
OpTester test( "Greater", 9 );
|
||||
std::vector<int64_t> dims { 4 };
|
||||
test.AddInput<float>( "A", dims, { 1.0f, 0.0f, -1.0f, -1.0f } );
|
||||
test.AddInput<float>( "B", dims, { 1.0f, 1.0f, 2.0f, -1.0f } );
|
||||
test.AddOutput<bool>( "C", dims, { false, false, false, false } );
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST( MathOpTest, Greater_9_int32 )
|
||||
{
|
||||
OpTester test( "Greater", 9 );
|
||||
std::vector<int64_t> dims { 4 };
|
||||
test.AddInput<int32_t>( "A", dims, { 10, 11, 12, 13 } );
|
||||
test.AddInput<int32_t>( "B", dims, { 15, 7, 12, 9 } );
|
||||
test.AddOutput<bool>( "C", dims, { false, true, false, true } );
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST( MathOpTest, Greater_9_int64 )
|
||||
{
|
||||
OpTester test( "Greater", 9 );
|
||||
std::vector<int64_t> dims { 4 };
|
||||
test.AddInput<int64_t>( "A", dims, { 10, 11, 12, 13 } );
|
||||
test.AddInput<int64_t>( "B", dims, { 15, 7, 12, 9 } );
|
||||
test.AddOutput<bool>( "C", dims, { false, true, false, true } );
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Equal_bool) {
|
||||
OpTester test("Equal");
|
||||
std::vector<int64_t> dims{4};
|
||||
|
|
|
|||
Loading…
Reference in a new issue