diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index 58d8c9b2a4..523402ed6b 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -178,6 +178,7 @@ static Status AddBinaryOperator(int32_t op_type, ModelBuilder& model_builder, const std::string& input1, const std::string& input2, + bool add_activation, int32_t fuse_code, const std::string& output, bool output_is_nhwc, @@ -187,6 +188,7 @@ static Status AddBinaryOperator(int32_t op_type, ModelBuilder& model_builder, const std::string& input1, const std::string& input2, + bool add_activation, int32_t fuse_code, const std::string& output, bool output_is_nhwc, @@ -199,7 +201,11 @@ static Status AddBinaryOperator(int32_t op_type, std::vector input_indices; input_indices.push_back(operand_indices.at(input1)); // input 1 input_indices.push_back(operand_indices.at(input2)); // input 2 - ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code); + + if (add_activation) { + ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code); + } + ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output)); const OperandType output_operand_type(operand_types.at(input1).type, shaper[output], output_scale, output_zero_point); @@ -752,6 +758,7 @@ void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N "Mul", "Div", "QLinearAdd", + "Pow", }); } @@ -760,16 +767,20 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto input_defs(node.InputDefs()); int32_t op_code; + bool add_activation = true; bool op_is_qlinear = op_type == "QLinearAdd"; - if (op_type == "Add" || op_is_qlinear) + if (op_type == "Add" || op_is_qlinear) { op_code = ANEURALNETWORKS_ADD; - else if (op_type == "Sub") + } else if (op_type == "Sub") { op_code = ANEURALNETWORKS_SUB; - else if (op_type 
== "Mul") + } else if (op_type == "Mul") { op_code = ANEURALNETWORKS_MUL; - else if (op_type == "Div") + } else if (op_type == "Div") { op_code = ANEURALNETWORKS_DIV; - else { + } else if (op_type == "Pow") { + add_activation = false; // ANEURALNETWORKS_POW does not have activation + op_code = ANEURALNETWORKS_POW; + } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder, unknown op: ", op_type); } @@ -805,9 +816,14 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input2, b_scale, b_zero_point)); } - int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE; + if (add_activation) { + fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + } + return AddBinaryOperator(op_code, model_builder, - input1, input2, fuse_code, + input1, input2, + add_activation, fuse_code, output, output_is_nhwc, y_scale, y_zero_point); } @@ -1116,7 +1132,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_MUL, model_builder, input, tensor_a_name, - ANEURALNETWORKS_FUSED_NONE, + true /* add_activation */, ANEURALNETWORKS_FUSED_NONE, tensor_imm_product_name, output_is_nhwc)); @@ -1125,7 +1141,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_ADD, model_builder, tensor_imm_product_name, tensor_b_name, - fuse_code, + true /* add_activation */, fuse_code, output, output_is_nhwc)); @@ -2411,6 +2427,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { NNAPI_EP_ADD_SHARED_OP_BUILDER("Mul", BinaryOpBuilder); NNAPI_EP_ADD_SHARED_OP_BUILDER("Div", BinaryOpBuilder); NNAPI_EP_ADD_SHARED_OP_BUILDER("QLinearAdd", BinaryOpBuilder); + NNAPI_EP_ADD_SHARED_OP_BUILDER("Pow", BinaryOpBuilder); } 
NNAPI_EP_ADD_SINGLE_OP_BUILDER("Relu", ReluOpBuilder); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index 47d2be430b..ebd0f8a339 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -194,6 +194,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker { "Mul", "Div", "QLinearAdd", + "Pow", }); } @@ -203,13 +204,18 @@ int32_t BinaryOpSupportChecker::GetMinSupportedSdkVer( if (op == "Sub" || op == "Div") { return 28; } + + if (op == "Pow") { + return 29; + } + return 27; } int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const { const auto& op(node.OpType()); - // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now + // Add/Sub/Mul/Div/Pow at opset 6 and below have broadcast attributes that we do not support yet if (op != "QLinearAdd") return 7; @@ -217,12 +223,37 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const { } bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const { - if (node.OpType() != "QLinearAdd") + bool is_qlinear_add = node.OpType() == "QLinearAdd"; + bool is_pow = node.OpType() == "Pow"; + if (!is_qlinear_add && !is_pow) return BaseOpSupportChecker::HasSupportedInputsImpl(node); - // QLinearAdd - if (!HasValidBinaryOpQuantizedInputs(node)) - return false; + if (is_qlinear_add) { + // QLinearAdd + if (!HasValidBinaryOpQuantizedInputs(node)) + return false; + } + + // For Pow, we currently only support the case where both inputs are fp32 + if (is_pow) { + const auto& input1 = *node.InputDefs()[0]; + const auto& input2 = *node.InputDefs()[1]; + + int32_t input_type_1; + if (!GetType(input1, input_type_1)) + return false; + + int32_t input_type_2; + if (!GetType(input2, input_type_2)) + return false; + + if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT 
|| input_type_1 != input_type_2) { + LOGS_DEFAULT(VERBOSE) << "Pow only supports fp32 inputs, actual input type" + << ", Input type 1: " << input_type_1 + << ", Input type 2: " << input_type_2; + return false; + } + } return true; } @@ -1362,6 +1393,7 @@ static OpSupportCheckerRegistrations CreateOpSupportCheckerRegistrations() { NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Mul", BinaryOpSupportChecker); NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Div", BinaryOpSupportChecker); NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("QLinearAdd", BinaryOpSupportChecker); + NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Pow", BinaryOpSupportChecker); } // Relu is always supported, we use BaseOpSupportChecker as default