mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-10 00:38:54 +00:00
Askhade/implement erf (#137)
* erf implementation for op9 * enable erf node tests + review comment fixes * update CMAKE flag * plus erf to execution provider
This commit is contained in:
parent
7d79bfef71
commit
b054646ddd
6 changed files with 38 additions and 1 deletions
|
|
@ -285,6 +285,7 @@ set_target_properties(onnx_proto PROPERTIES FOLDER "External/ONNX")
|
|||
# fix a warning in onnx code we can't do anything about
|
||||
if (MSVC)
|
||||
target_compile_options(onnx_proto PRIVATE /wd4146) # unary minus operator applied to unsigned type
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DEIGEN_HAS_C99_MATH") # required to be set explicitly to enable Eigen-Unsupported SpecialFunctions
|
||||
endif()
|
||||
set(onnxruntime_EXTERNAL_DEPENDENCIES gsl onnx_proto)
|
||||
|
||||
|
|
|
|||
|
|
@ -195,6 +195,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Con
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
|
||||
|
||||
void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
|
||||
|
|
@ -382,6 +383,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>());
|
||||
}
|
||||
|
||||
// Forward declarations of ml op kernels
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -311,6 +312,12 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Scale<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Erf,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Erf<float>);
|
||||
|
||||
template <typename T>
|
||||
Status Add<T>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastTwo<T, T>(
|
||||
|
|
@ -874,4 +881,16 @@ Status Scale<float>::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Erf<float>::Compute(OpKernelContext* context) const {
|
||||
auto X_ptr = context->Input<Tensor>(0);
|
||||
ONNXRUNTIME_ENFORCE(X_ptr != nullptr);
|
||||
auto& X = *X_ptr;
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
|
||||
EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -317,6 +317,15 @@ class Scale final : public OpKernel {
|
|||
float scale_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Erf final : public OpKernel {
|
||||
public:
|
||||
Erf(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
auto MakeEigenArrayMap(Tensor& t) { return EigenVectorArrayMap<T>(t.template MutableData<T>(), t.Shape().Size()); }
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -326,7 +326,6 @@ int real_main(int argc, char* argv[]) {
|
|||
{"acosh_example", "opset 9 not supported yet"},
|
||||
{"atanh_example", "opset 9 not supported yet"},
|
||||
{"sign_model", "opset 9 not supported yet"},
|
||||
{"erf", "opset 9 not supported yet"},
|
||||
{"sign", "opset 9 not supported yet"},
|
||||
{"scatter_with_axis", "opset 9 not supported yet"},
|
||||
{"scatter_without_axis", "opset 9 not supported yet"},
|
||||
|
|
|
|||
|
|
@ -869,6 +869,13 @@ TEST(MathOpTest, Scale_Default) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Erf) {
|
||||
OpTester test("Erf", 9);
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("A", dims, {0.5f, 1.0f, 0.7f, 2.0f});
|
||||
test.AddOutput<float>("B", dims, {0.5204999f, 0.8427008f, 0.6778012f, 0.9953223f});
|
||||
test.Run();
|
||||
}
|
||||
} // namespace test
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue