From b054646ddd9cccaee90d43dfae7620fa2d3f2c2b Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Mon, 10 Dec 2018 17:26:01 -0800 Subject: [PATCH] Askhade/implement erf (#137) * erf implementation for op9 * enable erf node tests + review comment fixes * update CMAKE flag * plus erf to execution provider --- cmake/CMakeLists.txt | 1 + .../providers/cpu/cpu_execution_provider.cc | 2 ++ .../providers/cpu/math/element_wise_ops.cc | 19 +++++++++++++++++++ .../providers/cpu/math/element_wise_ops.h | 9 +++++++++ onnxruntime/test/onnx/main.cc | 1 - .../cpu/math/element_wise_ops_test.cc | 7 +++++++ 6 files changed, 38 insertions(+), 1 deletion(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5f4c104009..a35671c461 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -285,6 +285,7 @@ set_target_properties(onnx_proto PROPERTIES FOLDER "External/ONNX") # fix a warning in onnx code we can't do anything about if (MSVC) target_compile_options(onnx_proto PRIVATE /wd4146) # unary minus operator applied to unsigned type + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DEIGEN_HAS_C99_MATH") # required to be set explicitly to enable Eigen-Unsupported SpecialFunctions endif() set(onnxruntime_EXTERNAL_DEPENDENCIES gsl onnx_proto) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f5303df02c..cdcb09e523 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -195,6 +195,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Con class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf); void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); @@ -382,6 +383,7 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); } // Forward declarations of ml op kernels diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 8f48195ea0..ab05ed82be 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/math/element_wise_ops.h" +#include namespace onnxruntime { @@ -311,6 +312,12 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Scale); +ONNX_CPU_OPERATOR_KERNEL( + Erf, + 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Erf); + template Status Add::Compute(OpKernelContext* context) const { return BroadcastTwo( @@ -874,4 +881,16 @@ Status Scale::Compute(OpKernelContext* ctx) const { return Status::OK(); } +template <> +Status Erf::Compute(OpKernelContext* context) const { + auto X_ptr = context->Input(0); + ONNXRUNTIME_ENFORCE(X_ptr != nullptr); + auto& X = *X_ptr; + auto& Y = *context->Output(0, X.Shape()); + + EigenMap(Y) = EigenMap(X).array().erf(); + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 912d352f39..feaa5c0a82 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -317,6 +317,15 @@ class Scale final : public OpKernel { float scale_; }; +template +class Erf final : public OpKernel { + public: + Erf(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + template auto MakeEigenArrayMap(Tensor& t) { return EigenVectorArrayMap(t.template MutableData(), t.Shape().Size()); } template diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 4871e1102c..77f717987e 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -326,7 +326,6 @@ int real_main(int argc, char* argv[]) { {"acosh_example", "opset 9 not supported yet"}, {"atanh_example", "opset 9 not supported yet"}, {"sign_model", "opset 9 not supported yet"}, - {"erf", "opset 9 not supported yet"}, {"sign", "opset 9 not supported yet"}, {"scatter_with_axis", "opset 9 not supported yet"}, {"scatter_without_axis", "opset 9 not supported yet"}, diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index e6eaf6c99e..e884554239 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -869,6 +869,13 @@ TEST(MathOpTest, Scale_Default) { test.Run(); } +TEST(MathOpTest, Erf) { + OpTester test("Erf", 9); + std::vector dims{2, 2}; + test.AddInput("A", dims, {0.5f, 1.0f, 0.7f, 2.0f}); + test.AddOutput("B", dims, {0.5204999f, 0.8427008f, 0.6778012f, 0.9953223f}); + test.Run(); +} } // namespace test } // namespace onnxruntime