mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Hyperbolic inv ops (#272)
Implement Inverse for hyberbolic ops Eigen will add support for asinh, acosh and atanh in the upcoming release. But until then for completeness of opset9 we have std based implementation.
This commit is contained in:
parent
d0fa974976
commit
8fba324678
5 changed files with 118 additions and 6 deletions
|
|
@ -234,6 +234,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Asinh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh);
|
||||
|
||||
void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
|
||||
|
|
@ -461,6 +464,9 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Asinh)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh)>());
|
||||
}
|
||||
|
||||
// Forward declarations of ml op kernels
|
||||
|
|
|
|||
|
|
@ -828,6 +828,102 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Cosh<float>);
|
||||
|
||||
template <typename T>
|
||||
class Asinh final : public OpKernel {
|
||||
public:
|
||||
explicit Asinh(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
auto& X = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
|
||||
auto X_data = X.template Data<float>();
|
||||
auto Y_data = Y.template MutableData<float>();
|
||||
|
||||
auto in = gsl::make_span(X_data, X.Shape().Size());
|
||||
auto out = gsl::make_span(Y_data, Y.Shape().Size());
|
||||
|
||||
for (int64_t index = 0; index < in.size(); ++index) {
|
||||
out[index] = std::asinh(in[index]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Asinh);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Asinh,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Asinh<float>);
|
||||
|
||||
template <typename T>
|
||||
class Acosh final : public OpKernel {
|
||||
public:
|
||||
explicit Acosh(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
auto& X = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
|
||||
auto X_data = X.template Data<float>();
|
||||
auto Y_data = Y.template MutableData<float>();
|
||||
|
||||
auto in = gsl::make_span(X_data, X.Shape().Size());
|
||||
auto out = gsl::make_span(Y_data, Y.Shape().Size());
|
||||
|
||||
for (int64_t index = 0; index < in.size(); ++index) {
|
||||
out[index] = std::acosh(in[index]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Acosh);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Acosh,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Acosh<float>);
|
||||
|
||||
template <typename T>
|
||||
class Atanh final : public OpKernel {
|
||||
public:
|
||||
explicit Atanh(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
auto& X = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
|
||||
auto X_data = X.template Data<float>();
|
||||
auto Y_data = Y.template MutableData<float>();
|
||||
|
||||
auto in = gsl::make_span(X_data, X.Shape().Size());
|
||||
auto out = gsl::make_span(Y_data, Y.Shape().Size());
|
||||
|
||||
for (int64_t index = 0; index < in.size(); ++index) {
|
||||
out[index] = std::atanh(in[index]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Atanh);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Atanh,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Atanh<float>);
|
||||
|
||||
template <>
|
||||
Status PRelu<float>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastTwo<float, float>(
|
||||
|
|
|
|||
|
|
@ -305,9 +305,6 @@ int real_main(int argc, char* argv[]) {
|
|||
{"prelu_broadcast", "disable reason"},
|
||||
{"prelu_example", "disable reason"},
|
||||
{"upsample_nearest", "opset 9 not supported yet"},
|
||||
{"asinh", "opset 9 not supported yet"},
|
||||
{"acosh", "opset 9 not supported yet"},
|
||||
{"atanh", "opset 9 not supported yet"},
|
||||
{"sinh_example", "opset 9 not supported yet"},
|
||||
{"cosh_example", "opset 9 not supported yet"},
|
||||
{"asinh_example", "opset 9 not supported yet"},
|
||||
|
|
|
|||
|
|
@ -829,6 +829,21 @@ TEST(MathOpTest, Cosh) {
|
|||
TrigTest<std::cosh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Asinh) {
|
||||
OpTester test("Asinh", 9);
|
||||
TrigTest<std::asinh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Acosh) {
|
||||
OpTester test("Acosh", 9);
|
||||
TrigTest<std::acosh>(test, {1.0f, 1.1f, 3.0f, 10.0f, 100.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Atanh) {
|
||||
OpTester test("Atanh", 9);
|
||||
TrigTest<std::atanh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<float>("data_0", {1}, {1.0f});
|
||||
|
|
|
|||
|
|
@ -16,11 +16,9 @@ backend_test = onnx.backend.test.BackendTest(c2, __name__)
|
|||
# Type not supported
|
||||
backend_test.exclude(r'(FLOAT16)')
|
||||
|
||||
backend_test.exclude(r'(test_acosh_cpu.*'
|
||||
backend_test.exclude(r'('
|
||||
'|test_acosh_example_cpu.*'
|
||||
'|test_asinh_cpu.*'
|
||||
'|test_asinh_example_cpu.*'
|
||||
'|test_atanh_cpu.*'
|
||||
'|test_atanh_example_cpu.*'
|
||||
'|test_convtranspose_1d_cpu.*'
|
||||
'|test_convtranspose_3d_cpu.*'
|
||||
|
|
|
|||
Loading…
Reference in a new issue