From b534f9fa5f9e72275a8a5d4fef8710bdc43d9a05 Mon Sep 17 00:00:00 2001
From: Zhang Lei
Date: Fri, 30 Nov 2018 10:16:22 -0800
Subject: [PATCH] Square and Cube optimization for Pow() operator. (#30)

---
 .../core/providers/cpu/math/element_wise_ops.cc | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
index bfc134b0f0..8f48195ea0 100644
--- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -389,10 +389,23 @@ Status Sqrt<float>::Compute(OpKernelContext* ctx) const {
 
 template <>
 Status Pow<float>::Compute(OpKernelContext* context) const {
+  const Tensor& Y = *context->Input<Tensor>(1);
+  std::function<void(EigenVectorMap<float>, ConstEigenVectorMap<float>, float)> input1scalar =
+      [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = Eigen::pow(input0.array(), input1); };
+  if (Y.Shape().Size() == 1) {
+    float value = * Y.Data<float>();
+    if (value == 2.0) {
+      input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::square(input0.array()); };
+    }
+    else if (value == 3.0) {
+      input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::cube(input0.array()); };
+    }
+  }
+
   return BroadcastTwo<float, float>(
       *context,
       [](EigenVectorMap<float> output, float input0, ConstEigenVectorMap<float> input1) { output = Eigen::pow(input0, input1.array()); },
-      [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = Eigen::pow(input0.array(), input1); },
+      input1scalar,
       [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, ConstEigenVectorMap<float> input1) { output = Eigen::pow(input0.array(), input1.array()); });
 }
 