Pow<float>: use Eigen::square / Eigen::cube when the exponent is the scalar 2 or 3.

Summary of the patch below (everything above the "diff --git" header is
commentary, not part of any hunk):

  * Pow<float>::Compute now reads input 1 (the exponent tensor) as Y before
    dispatching to BroadcastTwo.
  * The callback used when the exponent side of the broadcast is a scalar is
    held in a std::function named input1scalar. Its default is the previous
    behaviour: output = Eigen::pow(input0.array(), input1).
  * If Y holds exactly one element and that element equals 2.0, input1scalar
    is replaced with output = Eigen::square(input0.array()); if it equals 3.0,
    with output = Eigen::cube(input0.array()). Both replacements ignore the
    scalar exponent argument they are passed.
  * The other two BroadcastTwo callbacks (scalar base with vector exponent,
    and vector base with vector exponent) are unchanged and still call
    Eigen::pow.

NOTE(review): this patch was restored from a copy that had been collapsed onto
a single line with its <...> template argument lists removed. The arguments on
Pow, Input, Data, std::function, EigenVectorMap and ConstEigenVectorMap follow
from the surrounding tokens (template <>, const Tensor&, float value, the
lambda parameter lists). BroadcastTwo<float, float> and the exact indentation
of the context lines are assumptions -- confirm against the target file before
applying, or apply with whitespace tolerance.

diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
index bfc134b0f0..8f48195ea0 100644
--- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -389,10 +389,23 @@ Status Sqrt::Compute(OpKernelContext* ctx) const {
 
 template <>
 Status Pow<float>::Compute(OpKernelContext* context) const {
+  const Tensor& Y = *context->Input<Tensor>(1);
+  std::function<void(EigenVectorMap<float>, ConstEigenVectorMap<float>, float)> input1scalar =
+      [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = Eigen::pow(input0.array(), input1); };
+  if (Y.Shape().Size() == 1) {
+    float value = * Y.Data<float>();
+    if (value == 2.0) {
+      input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::square(input0.array()); };
+    }
+    else if (value == 3.0) {
+      input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) { output = Eigen::cube(input0.array()); };
+    }
+  }
+
   return BroadcastTwo<float, float>(
       *context,
       [](EigenVectorMap<float> output, float input0, ConstEigenVectorMap<float> input1) { output = Eigen::pow(input0, input1.array()); },
-      [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) { output = Eigen::pow(input0.array(), input1); },
+      input1scalar,
       [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, ConstEigenVectorMap<float> input1) { output = Eigen::pow(input0.array(), input1.array()); });
 }
 