Square and Cube optimization for Pow<float>() operator. (#30)

This commit is contained in:
Zhang Lei 2018-11-30 10:16:22 -08:00 committed by GitHub
parent bcc8f621ea
commit b534f9fa5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -389,10 +389,23 @@ Status Sqrt<float>::Compute(OpKernelContext* ctx) const {
// Element-wise power for float tensors: output = input0 ^ input1, evaluated
// through BroadcastTwo with one functor per operand arrangement (scalar base,
// scalar exponent, and two vector operands).
//
// When the exponent tensor (input 1) holds exactly one element equal to 2 or
// 3, the scalar-exponent functor is swapped for Eigen::square / Eigen::cube,
// which avoids the general-purpose pow() for these common exponents.
//
// Returns whatever Status BroadcastTwo produces.
template <>
Status Pow<float>::Compute(OpKernelContext* context) const {
  const Tensor& Y = *context->Input<Tensor>(1);

  // Functor for the "vector base, scalar exponent" case. Held in a
  // std::function so it can be replaced at run time; each lambda below has a
  // distinct closure type, so a plain `auto` variable could not be reassigned.
  std::function<void(EigenVectorMap<float>, ConstEigenVectorMap<float>, float)> input1scalar =
      [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float input1) {
        output = Eigen::pow(input0.array(), input1);
      };

  // Only a single-element exponent tensor qualifies for the fast paths.
  if (Y.Shape().Size() == 1) {
    const float value = *Y.Data<float>();
    // Exact comparison is intentional: 2 and 3 are exactly representable in
    // float, and only those exact exponents may be replaced by square/cube.
    if (value == 2.0f) {
      input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) {
        output = Eigen::square(input0.array());
      };
    } else if (value == 3.0f) {
      input1scalar = [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, float) {
        output = Eigen::cube(input0.array());
      };
    }
  }

  return BroadcastTwo<float, float>(
      *context,
      // Scalar base, vector exponent.
      [](EigenVectorMap<float> output, float input0, ConstEigenVectorMap<float> input1) {
        output = Eigen::pow(input0, input1.array());
      },
      // Vector base, scalar exponent (possibly specialized above).
      input1scalar,
      // Vector base, vector exponent.
      [](EigenVectorMap<float> output, ConstEigenVectorMap<float> input0, ConstEigenVectorMap<float> input1) {
        output = Eigen::pow(input0.array(), input1.array());
      });
}