mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Support double for operators Relu, Tanh, Sigmoid (#6221)
This commit is contained in:
parent
111ac299cc
commit
df7e2f3c1e
5 changed files with 116 additions and 75 deletions
|
|
@ -35,15 +35,26 @@ namespace onnxruntime {
|
|||
alias, newVersion, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), x<float>);
|
||||
|
||||
#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(alias, x, sinceVersion, firstEnd, newVersion, type) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
alias, sinceVersion, firstEnd, type, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), x<type>); \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
alias, newVersion, type, \
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), x<type>);
|
||||
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Elu, 6);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(HardSigmoid, 6);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(LeakyRelu, 6);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(Relu, Relu, 6, 12, 13);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Relu, Relu, 6, 12, 13, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Relu, Relu, 6, 12, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Selu, 6);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(Sigmoid, Sigmoid, 6, 12, 13);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Sigmoid, Sigmoid, 6, 12, 13, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Sigmoid, Sigmoid, 6, 12, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13, float);
|
||||
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13, double);
|
||||
REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10);
|
||||
|
||||
namespace functors {
|
||||
|
|
|
|||
|
|
@ -33,14 +33,15 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Elu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, HardSigmoid);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, LeakyRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Relu);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Relu);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Relu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Selu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sigmoid);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Sigmoid);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Sigmoid);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softplus);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softsign);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Tanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Tanh);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Tanh);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Tanh);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, PRelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomNormal);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomUniform);
|
||||
|
|
@ -548,7 +549,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Ceil);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sqrt);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sqrt);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Relu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Relu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Relu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sigmoid);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sigmoid);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Tanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Tanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log);
|
||||
|
|
@ -666,13 +672,15 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Elu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, HardSigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, LeakyRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Relu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Relu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Relu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Selu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softsign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Tanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Tanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Tanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Tanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8,
|
||||
PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomNormal)>,
|
||||
|
|
@ -1382,7 +1390,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t,
|
||||
DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t,
|
||||
|
|
@ -1499,7 +1508,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Ceil)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Relu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Relu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Relu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Tanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Tanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log)>,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace test{
|
|||
|
||||
TEST_F(ActivationOpTest, ThresholdedRelu_version_1_to_9) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp(
|
||||
TestActivationOp<float>(
|
||||
"ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, true, 1);
|
||||
}
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ TEST_F(ActivationOpTest, ScaledTanh) {
|
|||
static constexpr float alpha = 2.0f;
|
||||
static constexpr float beta = 1.5f;
|
||||
|
||||
TestActivationOp("ScaledTanh", input_values, [](float x) { return alpha * tanh(beta * x); },
|
||||
TestActivationOp<float>("ScaledTanh", input_values, [](float x) { return alpha * tanh(beta * x); },
|
||||
{{"alpha", alpha}, {"beta", beta}});
|
||||
}
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ TEST_F(ActivationOpTest, ParametricSoftplus) {
|
|||
static constexpr float alpha = 2.0f;
|
||||
static constexpr float beta = 1.5f;
|
||||
|
||||
TestActivationOp("ParametricSoftplus", input_values,
|
||||
TestActivationOp<float>("ParametricSoftplus", input_values,
|
||||
[](float x) {
|
||||
float bx = beta * x;
|
||||
if (bx > 0)
|
||||
|
|
@ -43,7 +43,7 @@ TEST_F(ActivationOpTest, ParametricSoftplus) {
|
|||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Gelu) {
|
||||
TestActivationOp(
|
||||
TestActivationOp<float>(
|
||||
"Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2))); }, {},
|
||||
false, 1, kMSDomain);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,57 +8,70 @@ namespace onnxruntime {
|
|||
namespace test {
|
||||
|
||||
TEST_F(ActivationOpTest, Sigmoid) {
|
||||
TestActivationOp("Sigmoid",
|
||||
input_values,
|
||||
[](float x) {
|
||||
auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid
|
||||
y = x > 0 ? y : 1 - y;
|
||||
return y;
|
||||
});
|
||||
TestActivationOp<float>("Sigmoid",
|
||||
input_values,
|
||||
[](float x) {
|
||||
auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid
|
||||
y = x > 0 ? y : 1 - y;
|
||||
return y;
|
||||
});
|
||||
TestActivationOp<double>("Sigmoid",
|
||||
input_values_double,
|
||||
[](double x) {
|
||||
auto y = 1. / (1. + std::exp(-std::abs(x))); // safe sigmoid
|
||||
y = x > 0 ? y : 1 - y;
|
||||
return y;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, HardSigmoid) {
|
||||
float alpha = 0.2f;
|
||||
float beta = 0.5f;
|
||||
TestActivationOp("HardSigmoid",
|
||||
input_values,
|
||||
[alpha, beta](float x) {
|
||||
return std::max(std::min((alpha * x + beta), 1.0f), 0.0f);
|
||||
},
|
||||
{{"alpha", alpha}, {"beta", beta}});
|
||||
TestActivationOp<float>("HardSigmoid",
|
||||
input_values,
|
||||
[alpha, beta](float x) {
|
||||
return std::max(std::min((alpha * x + beta), 1.0f), 0.0f);
|
||||
},
|
||||
{{"alpha", alpha}, {"beta", beta}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Tanh) {
|
||||
TestActivationOp("Tanh",
|
||||
input_values,
|
||||
[](float x) { return std::tanh(x); });
|
||||
TestActivationOp<float>("Tanh",
|
||||
input_values,
|
||||
[](float x) { return std::tanh(x); });
|
||||
TestActivationOp<double>("Tanh",
|
||||
input_values_double,
|
||||
[](double x) { return std::tanh(x); });
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Relu) {
|
||||
TestActivationOp("Relu",
|
||||
input_values,
|
||||
[](float x) { return std::max(x, 0.0f); });
|
||||
TestActivationOp<float>("Relu",
|
||||
input_values,
|
||||
[](float x) { return std::max(x, 0.0f); });
|
||||
TestActivationOp<double>("Relu",
|
||||
input_values_double,
|
||||
[](double x) { return std::max(x, 0.0); });
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Elu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp("Elu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); },
|
||||
{{"alpha", alpha}});
|
||||
TestActivationOp<float>("Elu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); },
|
||||
{{"alpha", alpha}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, LeakyRelu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp("LeakyRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * x; },
|
||||
{{"alpha", alpha}});
|
||||
TestActivationOp<float>("LeakyRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * x; },
|
||||
{{"alpha", alpha}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, ThresholdedRelu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp(
|
||||
TestActivationOp<float>(
|
||||
"ThresholdedRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
|
|
@ -69,20 +82,20 @@ TEST_F(ActivationOpTest, Selu) {
|
|||
static constexpr float alpha = 1.6732f;
|
||||
static constexpr float gamma = 1.0507f;
|
||||
|
||||
TestActivationOp("Selu",
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
TestActivationOp<float>("Selu",
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Selu_Attributes) {
|
||||
static constexpr float alpha = 1.8f;
|
||||
static constexpr float gamma = 0.5f;
|
||||
|
||||
TestActivationOp("Selu",
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
TestActivationOp<float>("Selu",
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, PRelu) {
|
||||
|
|
@ -144,18 +157,18 @@ TEST_F(ActivationOpTest, PRelu_MultiChannel) {
|
|||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Softplus) {
|
||||
TestActivationOp("Softplus",
|
||||
input_values,
|
||||
[](float x) {
|
||||
if (x > 0)
|
||||
return x + logf(expf(-x) + 1);
|
||||
else
|
||||
return logf(expf(x) + 1);
|
||||
});
|
||||
TestActivationOp<float>("Softplus",
|
||||
input_values,
|
||||
[](float x) {
|
||||
if (x > 0)
|
||||
return x + logf(expf(-x) + 1);
|
||||
else
|
||||
return logf(expf(x) + 1);
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpNoInfTest, Softsign) {
|
||||
TestActivationOp(
|
||||
TestActivationOp<float>(
|
||||
"Softsign",
|
||||
input_values,
|
||||
[](float x) {
|
||||
|
|
|
|||
|
|
@ -13,22 +13,23 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
inline void TestActivationOp(const char* szOp, const std::vector<std::vector<float>>& input_vals_vec,
|
||||
std::function<float(float)> expected_func,
|
||||
template <typename T>
|
||||
inline void TestActivationOp(const char* szOp, const std::vector<std::vector<T>>& input_vals_vec,
|
||||
std::function<T(T)> expected_func,
|
||||
const std::unordered_map<std::string, float> attribs = {},
|
||||
bool is_tensorrt_supported = true, int opset_version = 7,
|
||||
const char* domain = kOnnxDomain) {
|
||||
for (const std::vector<float>& input_vals : input_vals_vec) {
|
||||
for (const std::vector<T>& input_vals : input_vals_vec) {
|
||||
OpTester test(szOp, opset_version, domain);
|
||||
|
||||
for (auto attr : attribs) test.AddAttribute(attr.first, attr.second);
|
||||
for (auto attr : attribs) test.AddAttribute<float>(attr.first, attr.second);
|
||||
std::vector<int64_t> dims{(int64_t)input_vals.size()};
|
||||
|
||||
std::vector<float> expected_vals;
|
||||
std::vector<T> expected_vals;
|
||||
for (const auto& iv : input_vals) expected_vals.push_back(expected_func(iv));
|
||||
|
||||
test.AddInput<float>("X", dims, input_vals);
|
||||
test.AddOutput<float>("Y", dims, expected_vals);
|
||||
test.AddInput<T>("X", dims, input_vals);
|
||||
test.AddOutput<T>("Y", dims, expected_vals);
|
||||
|
||||
// Disable TensorRT on unsupported tests
|
||||
std::unordered_set<std::string> excluded_providers;
|
||||
|
|
@ -74,10 +75,14 @@ inline void TestActivationOp(const char* szOp, const std::vector<std::vector<flo
|
|||
|
||||
class ActivationOpTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<float>> input_values{{-1.0f, 0, 1.0f, // normal input values for activation
|
||||
100.0f, -100.0f, 1000.0f, -1000.0f, // input values that leads to exp() overflow
|
||||
FLT_MIN, FLT_MIN / 10, -FLT_MIN / 10, // min, denorm, -denorm
|
||||
FLT_MAX, -FLT_MAX, std::numeric_limits<float>::infinity()}}; // max, -max, inf
|
||||
std::vector<std::vector<float>> input_values{{-1.0f, 0, 1.0f, // normal input values for activation
|
||||
100.0f, -100.0f, 1000.0f, -1000.0f, // input values that leads to exp() overflow
|
||||
FLT_MIN, FLT_MIN / 10, -FLT_MIN / 10, // min, denorm, -denorm
|
||||
FLT_MAX, -FLT_MAX, std::numeric_limits<float>::infinity()}}; // max, -max, inf
|
||||
std::vector<std::vector<double>> input_values_double{{-1.0, 0, 1.0, // normal input values for activation
|
||||
100.0, -100.0, 1000.0, -1000.0, // input values that leads to exp() overflow
|
||||
DBL_MIN, DBL_MIN / 10, -DBL_MIN / 10, // min, denorm, -denorm
|
||||
DBL_MAX, -DBL_MAX, std::numeric_limits<double>::infinity()}}; // max, -max, inf
|
||||
|
||||
void SetUp() override {
|
||||
float low = -1.0f, high = 1.0f;
|
||||
|
|
|
|||
Loading…
Reference in a new issue