Askhade/implement erf (#137)

* erf implementation for op9

* enable erf node tests + review comment fixes

* update CMAKE flag

* plus erf to execution provider
This commit is contained in:
Ashwini Khade 2018-12-10 17:26:01 -08:00 committed by GitHub
parent 7d79bfef71
commit b054646ddd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 1 deletions

View file

@ -285,6 +285,7 @@ set_target_properties(onnx_proto PROPERTIES FOLDER "External/ONNX")
# fix a warning in onnx code we can't do anything about
if (MSVC)
target_compile_options(onnx_proto PRIVATE /wd4146) # unary minus operator applied to unsigned type
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DEIGEN_HAS_C99_MATH") # required to be set explicitly to enable Eigen-Unsupported SpecialFunctions
endif()
set(onnxruntime_EXTERNAL_DEPENDENCIES gsl onnx_proto)

View file

@ -195,6 +195,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Con
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
@ -382,6 +383,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>());
}
// Forward declarations of ml op kernels

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/cpu/math/element_wise_ops.h"
#include <unsupported/Eigen/SpecialFunctions>
namespace onnxruntime {
@ -311,6 +312,12 @@ ONNX_CPU_OPERATOR_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Scale<float>);
ONNX_CPU_OPERATOR_KERNEL(
Erf,
9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Erf<float>);
template <typename T>
Status Add<T>::Compute(OpKernelContext* context) const {
return BroadcastTwo<T, T>(
@ -874,4 +881,16 @@ Status Scale<float>::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
template <>
Status Erf<float>::Compute(OpKernelContext* context) const {
auto X_ptr = context->Input<Tensor>(0);
ONNXRUNTIME_ENFORCE(X_ptr != nullptr);
auto& X = *X_ptr;
auto& Y = *context->Output(0, X.Shape());
EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
return Status::OK();
}
} // namespace onnxruntime

View file

@ -317,6 +317,15 @@ class Scale final : public OpKernel {
float scale_;
};
template <typename T>
class Erf final : public OpKernel {
public:
Erf(const OpKernelInfo& info) : OpKernel(info) {
}
Status Compute(OpKernelContext* context) const override;
};
template <typename T>
auto MakeEigenArrayMap(Tensor& t) { return EigenVectorArrayMap<T>(t.template MutableData<T>(), t.Shape().Size()); }
template <typename T>

View file

@ -326,7 +326,6 @@ int real_main(int argc, char* argv[]) {
{"acosh_example", "opset 9 not supported yet"},
{"atanh_example", "opset 9 not supported yet"},
{"sign_model", "opset 9 not supported yet"},
{"erf", "opset 9 not supported yet"},
{"sign", "opset 9 not supported yet"},
{"scatter_with_axis", "opset 9 not supported yet"},
{"scatter_without_axis", "opset 9 not supported yet"},

View file

@ -869,6 +869,13 @@ TEST(MathOpTest, Scale_Default) {
test.Run();
}
TEST(MathOpTest, Erf) {
OpTester test("Erf", 9);
std::vector<int64_t> dims{2, 2};
test.AddInput<float>("A", dims, {0.5f, 1.0f, 0.7f, 2.0f});
test.AddOutput<float>("B", dims, {0.5204999f, 0.8427008f, 0.6778012f, 0.9953223f});
test.Run();
}
} // namespace test
} // namespace onnxruntime