diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 66b626ec06..595384e0f5 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -234,6 +234,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh); void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); @@ -461,6 +464,9 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); } // Forward declarations of ml op kernels diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 1f8324fbef..5ebbd9512d 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -828,6 +828,102 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Cosh); +template +class Asinh final : public OpKernel { + public: + explicit Asinh(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override { + auto& X = *context->Input(0); + auto& Y = *context->Output(0, X.Shape()); + + auto X_data = X.template Data(); + auto Y_data = Y.template MutableData(); + + auto in = gsl::make_span(X_data, X.Shape().Size()); + auto out = gsl::make_span(Y_data, Y.Shape().Size()); + + for (int64_t index = 0; index < in.size(); ++index) { + out[index] = std::asinh(in[index]); + } + return Status::OK(); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Asinh); +}; + +ONNX_CPU_OPERATOR_KERNEL( + Asinh, + 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Asinh); + +template +class Acosh final : public OpKernel { + public: + explicit Acosh(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override { + auto& X = *context->Input(0); + auto& Y = *context->Output(0, X.Shape()); + + auto X_data = X.template Data(); + auto Y_data = Y.template MutableData(); + + auto in = gsl::make_span(X_data, X.Shape().Size()); + auto out = gsl::make_span(Y_data, Y.Shape().Size()); + + for (int64_t index = 0; index < in.size(); ++index) { + out[index] = std::acosh(in[index]); + } + return Status::OK(); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Acosh); +}; + +ONNX_CPU_OPERATOR_KERNEL( + Acosh, + 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Acosh); + +template +class Atanh final : public OpKernel { + public: + explicit Atanh(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override { + auto& X = *context->Input(0); + auto& Y = *context->Output(0, X.Shape()); + + auto X_data = X.template Data(); + auto Y_data = Y.template MutableData(); + + auto in = gsl::make_span(X_data, X.Shape().Size()); + auto out = gsl::make_span(Y_data, Y.Shape().Size()); + + for (int64_t index = 0; index < in.size(); ++index) { + out[index] = std::atanh(in[index]); + } + return Status::OK(); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Atanh); +}; + +ONNX_CPU_OPERATOR_KERNEL( + Atanh, + 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Atanh); + template <> Status PRelu::Compute(OpKernelContext* context) const { return BroadcastTwo( diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index fac6f4a823..89c99cdb08 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -305,9 +305,6 @@ int real_main(int argc, char* argv[]) { {"prelu_broadcast", "disable reason"}, {"prelu_example", "disable reason"}, {"upsample_nearest", "opset 9 not supported yet"}, - {"asinh", "opset 9 not supported yet"}, - {"acosh", "opset 9 not supported yet"}, - {"atanh", "opset 9 not supported yet"}, {"sinh_example", "opset 9 not supported yet"}, {"cosh_example", "opset 9 not supported yet"}, {"asinh_example", "opset 9 not supported yet"}, diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index fa1b369b2d..90650e6df4 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -829,6 +829,21 @@ TEST(MathOpTest, Cosh) { TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } +TEST(MathOpTest, Asinh) { + OpTester test("Asinh", 9); + TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); +} + +TEST(MathOpTest, Acosh) { + OpTester test("Acosh", 9); + TrigTest(test, {1.0f, 1.1f, 3.0f, 10.0f, 100.0f}); +} + +TEST(MathOpTest, Atanh) { + OpTester test("Atanh", 9); + TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); +} + TEST(MathOpTest, Expand_8_3x3) { OpTester test("Expand", 8); test.AddInput("data_0", {1}, {1.0f}); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 70f5e99fad..992a56bcbb 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -16,11 +16,9 @@ backend_test = onnx.backend.test.BackendTest(c2, __name__) # Type not supported backend_test.exclude(r'(FLOAT16)') -backend_test.exclude(r'(test_acosh_cpu.*' +backend_test.exclude(r'(' '|test_acosh_example_cpu.*' -'|test_asinh_cpu.*' '|test_asinh_example_cpu.*' -'|test_atanh_cpu.*' '|test_atanh_example_cpu.*' '|test_convtranspose_1d_cpu.*' '|test_convtranspose_3d_cpu.*'