mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
support hyperbolic ops (#223)
* support hyperbolic fns This commit adds support for sinh and cosh. Support for hyperbolic inverses is not available in Eigen yet. * Make constructors explicit * remove tests from exclude list * Revert "remove tests from exclude list" This reverts commit 2112a30b57d5a899991de4847e948e700a44e85d. * remove test names from excluded list * remove tanh since its already implemented
This commit is contained in:
parent
e63572c1f3
commit
84231ba003
5 changed files with 73 additions and 25 deletions
|
|
@ -196,6 +196,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Eye
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh);
|
||||
|
||||
void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
|
||||
|
|
@ -384,6 +386,8 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh)>());
|
||||
}
|
||||
|
||||
// Forward declarations of ml op kernels
|
||||
|
|
@ -485,7 +489,7 @@ static void RegisterCPUKernels(std::function<void(KernelCreateInfo&&)> create_fn
|
|||
|
||||
std::shared_ptr<KernelRegistry> CPUExecutionProvider::GetKernelRegistry() const {
|
||||
static std::shared_ptr<KernelRegistry>
|
||||
kernel_registry = std::make_shared<KernelRegistry>(RegisterCPUKernels);
|
||||
kernel_registry = std::make_shared<KernelRegistry>(RegisterCPUKernels);
|
||||
return kernel_registry;
|
||||
}
|
||||
|
||||
|
|
@ -493,7 +497,7 @@ std::vector<std::unique_ptr<ComputeCapability>>
|
|||
CPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) const {
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
result = IExecutionProvider::GetCapability(graph, kernel_registries);
|
||||
result = IExecutionProvider::GetCapability(graph, kernel_registries);
|
||||
|
||||
for (auto& rule : fuse_rules_) {
|
||||
rule(graph, result);
|
||||
|
|
|
|||
|
|
@ -400,11 +400,10 @@ Status Pow<float>::Compute(OpKernelContext* context) const {
|
|||
std::function<void(EigenVectorMap<float>, ConstEigenVectorMap<float>, float)> input1scalar =
|
||||
[](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = Eigen::pow(input0.array(), input1); };
|
||||
if (Y.Shape().Size() == 1) {
|
||||
float value = * Y.Data<float>();
|
||||
float value = *Y.Data<float>();
|
||||
if (value == 2.0) {
|
||||
input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::square(input0.array()); };
|
||||
}
|
||||
else if (value == 3.0) {
|
||||
} else if (value == 3.0) {
|
||||
input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::cube(input0.array()); };
|
||||
}
|
||||
}
|
||||
|
|
@ -789,6 +788,46 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Atan<float>);
|
||||
|
||||
template <typename T>
|
||||
class Sinh final : public OpKernel {
|
||||
public:
|
||||
explicit Sinh(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
auto& X = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
MakeEigenArrayMap<float>(Y) = MakeEigenArrayMap<float>(X).sinh();
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Sinh,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Sinh<float>);
|
||||
|
||||
template <typename T>
|
||||
class Cosh final : public OpKernel {
|
||||
public:
|
||||
explicit Cosh(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
auto& X = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
MakeEigenArrayMap<float>(Y) = MakeEigenArrayMap<float>(X).cosh();
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Cosh,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Cosh<float>);
|
||||
|
||||
template <>
|
||||
Status PRelu<float>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastTwo<float, float>(
|
||||
|
|
@ -887,7 +926,6 @@ Status Erf<float>::Compute(OpKernelContext* context) const {
|
|||
ORT_ENFORCE(X_ptr != nullptr);
|
||||
auto& X = *X_ptr;
|
||||
auto& Y = *context->Output(0, X.Shape());
|
||||
|
||||
EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -314,8 +314,6 @@ int real_main(int argc, char* argv[]) {
|
|||
{"upsample_nearest", "opset 9 not supported yet"},
|
||||
{"onehot_with_axis", "opset 9 not supported yet"},
|
||||
{"onehot_without_axis", "opset 9 not supported yet"}, // also has bug in current test re: output type. Spandan to fix.
|
||||
{"sinh", "opset 9 not supported yet"},
|
||||
{"cosh", "opset 9 not supported yet"},
|
||||
{"asinh", "opset 9 not supported yet"},
|
||||
{"acosh", "opset 9 not supported yet"},
|
||||
{"atanh", "opset 9 not supported yet"},
|
||||
|
|
|
|||
|
|
@ -163,11 +163,11 @@ TEST(MathOpTest, Sub_int32) {
|
|||
}
|
||||
|
||||
TEST(MathOpTest, Sub_int64) {
|
||||
OpTester test("Sub");
|
||||
test.AddInput<int64_t>("A", { 3 }, { 1, 5, 6 });
|
||||
test.AddInput<int64_t>("B", { 3 }, { 4, 5, 3 });
|
||||
test.AddOutput<int64_t>("C", { 3 }, { -3, 0, 3 });
|
||||
test.Run();
|
||||
OpTester test("Sub");
|
||||
test.AddInput<int64_t>("A", {3}, {1, 5, 6});
|
||||
test.AddInput<int64_t>("B", {3}, {4, 5, 3});
|
||||
test.AddOutput<int64_t>("C", {3}, {-3, 0, 3});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sub) {
|
||||
|
|
@ -212,11 +212,11 @@ TEST(MathOpTest, Mul_int32) {
|
|||
}
|
||||
|
||||
TEST(MathOpTest, Mul_int64) {
|
||||
OpTester test("Mul");
|
||||
test.AddInput<int64_t>("A", { 3 }, { 3, 6, -3 });
|
||||
test.AddInput<int64_t>("B", { 3 }, { 4, -3, -2 });
|
||||
test.AddOutput<int64_t>("C", { 3 }, { 12, -18, 6 });
|
||||
test.Run();
|
||||
OpTester test("Mul");
|
||||
test.AddInput<int64_t>("A", {3}, {3, 6, -3});
|
||||
test.AddInput<int64_t>("B", {3}, {4, -3, -2});
|
||||
test.AddOutput<int64_t>("C", {3}, {12, -18, 6});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Mul) {
|
||||
|
|
@ -246,11 +246,11 @@ TEST(MathOpTest, Div_int32) {
|
|||
}
|
||||
|
||||
TEST(MathOpTest, Div_int64) {
|
||||
OpTester test("Div");
|
||||
test.AddInput<int64_t>("A", { 3 }, { 4, 8, 8 });
|
||||
test.AddInput<int64_t>("B", { 3 }, { 2, 3, 4 });
|
||||
test.AddOutput<int64_t>("C", { 3 }, { 2, 2, 2 });
|
||||
test.Run();
|
||||
OpTester test("Div");
|
||||
test.AddInput<int64_t>("A", {3}, {4, 8, 8});
|
||||
test.AddInput<int64_t>("B", {3}, {2, 3, 4});
|
||||
test.AddOutput<int64_t>("C", {3}, {2, 2, 2});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Div) {
|
||||
|
|
@ -819,6 +819,16 @@ TEST(MathOpTest, Atan) {
|
|||
TrigTest<std::atan>(test, {-10.0f, -5.0f, 0.0f, 5.0f, 10.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sinh) {
|
||||
OpTester test("Sinh", 9);
|
||||
TrigTest<std::sinh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Cosh) {
|
||||
OpTester test("Cosh", 9);
|
||||
TrigTest<std::cosh>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<float>("data_0", {1}, {1.0f});
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ backend_test.exclude(r'(test_acosh_cpu.*'
|
|||
'|test_atanh_example_cpu.*'
|
||||
'|test_convtranspose_1d_cpu.*'
|
||||
'|test_convtranspose_3d_cpu.*'
|
||||
'|test_cosh_cpu.*'
|
||||
'|test_cosh_example_cpu.*'
|
||||
'|test_dynamic_slice_cpu.*'
|
||||
'|test_dynamic_slice_default_axes_cpu.*'
|
||||
|
|
@ -43,7 +42,6 @@ backend_test.exclude(r'(test_acosh_cpu.*'
|
|||
'|test_scatter_with_axis_cpu.*'
|
||||
'|test_scatter_without_axis_cpu.*'
|
||||
'|test_sign_cpu.*'
|
||||
'|test_sinh_cpu.*'
|
||||
'|test_sinh_example_cpu.*'
|
||||
'|test_AvgPool1d_cpu.*'
|
||||
'|test_AvgPool1d_stride_cpu.*'
|
||||
|
|
|
|||
Loading…
Reference in a new issue