diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc index ffea82c2c7..1c2dfb79dc 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.cc +++ b/onnxruntime/core/providers/cpu/activation/activations.cc @@ -56,6 +56,7 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1); REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, float); REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, double); +REGISTER_UNARY_ELEMENTWISE_KERNEL(Celu, 12); REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, float); REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, double); REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10); @@ -64,6 +65,7 @@ namespace functors { template Status ElementWiseRangedTransform::Create(const std::string& type, const NodeAttributes& attributes, std::unique_ptr>& out) { + CREATE_ELE_KERNEL(Celu); CREATE_ELE_KERNEL(Elu); CREATE_ELE_KERNEL(HardSigmoid); CREATE_ELE_KERNEL(LeakyRelu); diff --git a/onnxruntime/core/providers/cpu/activation/activations.h b/onnxruntime/core/providers/cpu/activation/activations.h index 92a7478013..75a73a818d 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.h +++ b/onnxruntime/core/providers/cpu/activation/activations.h @@ -13,6 +13,23 @@ namespace onnxruntime { namespace functors { +template +struct Celu : public ElementWiseRangedTransform { + ORT_GET_FLOAT_ATTR_AND_RETURN(alpha); + + float Cost() const final { + // TODO: Tune the cost + return 1.0f; + } + void operator()(std::ptrdiff_t first, std::ptrdiff_t last) const final { + ptrdiff_t len = last - first; + T* output_ptr = this->output + first; + ConstEigenVectorArrayMap xm(this->input + first, len); + EigenVectorArrayMap ym(output_ptr, len); + ym = xm.cwiseMax(0.0f) + (((T)alpha * ((xm / (T)alpha).exp() - 1)).cwiseMin(0.0f)); + } +}; + template struct Elu : public ElementWiseRangedTransform { ORT_GET_FLOAT_ATTR_AND_RETURN(alpha); @@ -89,9 +106,9 @@ struct Relu : public ElementWiseRangedTransform { Status Init(const onnxruntime::NodeAttributes&) { return Status::OK(); } - ElementWiseRangedTransform* Copy() const { // replace it with a macro. why this? + ElementWiseRangedTransform* Copy() const { // replace it with a macro. why this? using T1 = typename std::remove_pointer::type; - using T2 = typename std::remove_const::type; //redundant? + using T2 = typename std::remove_const::type; //redundant? return new T2(*this); } float Cost() const final { @@ -212,6 +229,7 @@ struct Selu : public ElementWiseRangedTransform { } // namespace functors +DEFINE_ELE_KERNEL(Celu); DEFINE_ELE_KERNEL(Elu); DEFINE_ELE_KERNEL(HardSigmoid); DEFINE_ELE_KERNEL(LeakyRelu); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index e7ce121a8d..5f4188c763 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -465,6 +465,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu); // opset 13 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf); @@ -1093,7 +1094,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + float, Gemm)>, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 13 BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index f8b5be5930..104e62e62e 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -61,6 +61,16 @@ TEST_F(ActivationOpTest, Elu) { {{"alpha", alpha}}); } +TEST_F(ActivationOpTest, Celu) { + float alpha = -0.5f; + TestActivationOp( + "Celu", + input_values, + // TODO: Investigate why gcc 4 fails to compile without the explicit cast + [alpha](float x) { return std::max(0.0f, x) + std::min(0.0f, alpha * (static_cast(exp(x / alpha)) - 1)); }, + // Disable on TensorRT as it seems like it doesn't yet support Celu + {{"alpha", alpha}}, false, 12); +} TEST_F(ActivationOpTest, LeakyRelu) { float alpha = 0.1f; TestActivationOp("LeakyRelu",