mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Cleanup macros used to register activations. (#6628)
Registrations need to either be between a start and end version, or be the current version. Having a macro that uses 3 versions will break or lead to misuse when a 4th version is released.
This commit is contained in:
parent
1916e35bea
commit
ce01c3760f
1 changed files with 28 additions and 40 deletions
|
|
@ -20,41 +20,44 @@ namespace onnxruntime {
|
|||
return Status::OK(); \
|
||||
}
|
||||
|
||||
#define REGISTER_UNARY_ELEMENTWISE_KERNEL_ALIAS(alias, x, sinceVersion) \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
alias, sinceVersion, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>);
|
||||
#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL(op, since_version, end_version) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
|
||||
op, since_version, end_version, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), op<float>);
|
||||
|
||||
#define REGISTER_UNARY_ELEMENTWISE_KERNEL(x, sinceVersion) REGISTER_UNARY_ELEMENTWISE_KERNEL_ALIAS(x, x, sinceVersion)
|
||||
#define REGISTER_UNARY_ELEMENTWISE_KERNEL(op, since_version) \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
op, since_version, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), op<float>);
|
||||
|
||||
#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(alias, x, sinceVersion, firstEnd, newVersion) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
|
||||
alias, sinceVersion, firstEnd, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>); \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
alias, newVersion, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>);
|
||||
#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(op, since_version, end_version, type) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
op, since_version, end_version, type, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), op<type>);
|
||||
|
||||
#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(alias, x, sinceVersion, firstEnd, newVersion, type) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
alias, sinceVersion, firstEnd, type, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), x<type>); \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
alias, newVersion, type, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), x<type>);
|
||||
#define REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(op, since_version, type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
op, since_version, type, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), op<type>);
|
||||
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Elu, 6);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(HardSigmoid, 6);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(LeakyRelu, 6);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Relu, Relu, 6, 12, 13, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Relu, Relu, 6, 12, 13, double);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Relu, 6, 12, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Relu, 6, 12, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Relu, 13, float);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Relu, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Selu, 6);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Sigmoid, Sigmoid, 6, 12, 13, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Sigmoid, Sigmoid, 6, 12, 13, double);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Sigmoid, 6, 12, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Sigmoid, 6, 12, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Sigmoid, 13, float);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Sigmoid, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13, double);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, float);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10);
|
||||
|
||||
namespace functors {
|
||||
|
|
@ -82,21 +85,6 @@ template Status ElementWiseRangedTransform<float>::Create(const std::string& typ
|
|||
std::unique_ptr<ElementWiseRangedTransform<float>>& out);
|
||||
} // namespace functors
|
||||
|
||||
#define REGISTER_UNARY_ELEMENTWISE_KERNEL_ALIAS(alias, x, sinceVersion) \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
alias, sinceVersion, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>);
|
||||
|
||||
#define REGISTER_UNARY_ELEMENTWISE_KERNEL(x, sinceVersion) REGISTER_UNARY_ELEMENTWISE_KERNEL_ALIAS(x, x, sinceVersion)
|
||||
|
||||
#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(alias, x, sinceVersion, firstEnd, newVersion) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
|
||||
alias, sinceVersion, firstEnd, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>); \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
alias, newVersion, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>);
|
||||
|
||||
namespace functors {
|
||||
template <>
|
||||
void Sigmoid<float>::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const {
|
||||
|
|
|
|||
Loading…
Reference in a new issue