mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Support a CPU kernel for Celu (#6995)
This commit is contained in:
parent
d0cca35308
commit
27ac88201a
4 changed files with 35 additions and 3 deletions
|
|
@ -56,6 +56,7 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1);
|
|||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Celu, 12);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, float);
|
||||
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10);
|
||||
|
|
@ -64,6 +65,7 @@ namespace functors {
|
|||
template <typename T>
|
||||
Status ElementWiseRangedTransform<T>::Create(const std::string& type, const NodeAttributes& attributes,
|
||||
std::unique_ptr<ElementWiseRangedTransform<T>>& out) {
|
||||
CREATE_ELE_KERNEL(Celu);
|
||||
CREATE_ELE_KERNEL(Elu);
|
||||
CREATE_ELE_KERNEL(HardSigmoid);
|
||||
CREATE_ELE_KERNEL(LeakyRelu);
|
||||
|
|
|
|||
|
|
@ -13,6 +13,23 @@ namespace onnxruntime {
|
|||
|
||||
namespace functors {
|
||||
|
||||
template <typename T>
|
||||
struct Celu : public ElementWiseRangedTransform<T> {
|
||||
ORT_GET_FLOAT_ATTR_AND_RETURN(alpha);
|
||||
|
||||
float Cost() const final {
|
||||
// TODO: Tune the cost
|
||||
return 1.0f;
|
||||
}
|
||||
void operator()(std::ptrdiff_t first, std::ptrdiff_t last) const final {
|
||||
ptrdiff_t len = last - first;
|
||||
T* output_ptr = this->output + first;
|
||||
ConstEigenVectorArrayMap<T> xm(this->input + first, len);
|
||||
EigenVectorArrayMap<T> ym(output_ptr, len);
|
||||
ym = xm.cwiseMax(0.0f) + (((T)alpha * ((xm / (T)alpha).exp() - 1)).cwiseMin(0.0f));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Elu : public ElementWiseRangedTransform<T> {
|
||||
ORT_GET_FLOAT_ATTR_AND_RETURN(alpha);
|
||||
|
|
@ -89,9 +106,9 @@ struct Relu : public ElementWiseRangedTransform<T> {
|
|||
Status Init(const onnxruntime::NodeAttributes&) {
|
||||
return Status::OK();
|
||||
}
|
||||
ElementWiseRangedTransform<T>* Copy() const { // replace it with a macro. why this?
|
||||
ElementWiseRangedTransform<T>* Copy() const { // replace it with a macro. why this?
|
||||
using T1 = typename std::remove_pointer<decltype(this)>::type;
|
||||
using T2 = typename std::remove_const<T1>::type; //redundant?
|
||||
using T2 = typename std::remove_const<T1>::type; //redundant?
|
||||
return new T2(*this);
|
||||
}
|
||||
float Cost() const final {
|
||||
|
|
@ -212,6 +229,7 @@ struct Selu : public ElementWiseRangedTransform<T> {
|
|||
|
||||
} // namespace functors
|
||||
|
||||
DEFINE_ELE_KERNEL(Celu);
|
||||
DEFINE_ELE_KERNEL(Elu);
|
||||
DEFINE_ELE_KERNEL(HardSigmoid);
|
||||
DEFINE_ELE_KERNEL(LeakyRelu);
|
||||
|
|
|
|||
|
|
@ -465,6 +465,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu);
|
||||
|
||||
// opset 13
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf);
|
||||
|
|
@ -1093,7 +1094,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
float, Gemm)>,
|
||||
float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
|
||||
double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float,
|
||||
|
|
@ -1427,6 +1428,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu)>,
|
||||
|
||||
// opset 13
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Cast)>,
|
||||
|
|
|
|||
|
|
@ -61,6 +61,16 @@ TEST_F(ActivationOpTest, Elu) {
|
|||
{{"alpha", alpha}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Celu) {
|
||||
float alpha = -0.5f;
|
||||
TestActivationOp<float>(
|
||||
"Celu",
|
||||
input_values,
|
||||
// TODO: Investigate why gcc 4 fails to compile without the explicit cast
|
||||
[alpha](float x) { return std::max(0.0f, x) + std::min(0.0f, alpha * (static_cast<float>(exp(x / alpha)) - 1)); },
|
||||
// Disable on TensorRT as it seems like it doesn't yet support Celu
|
||||
{{"alpha", alpha}}, false, 12);
|
||||
}
|
||||
TEST_F(ActivationOpTest, LeakyRelu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp<float>("LeakyRelu",
|
||||
|
|
|
|||
Loading…
Reference in a new issue