[NNAPI] Add pow support (#6310)

This commit is contained in:
Guoyu Wang 2021-01-13 17:15:05 -08:00 committed by GitHub
parent fcd9fc9b6d
commit b220feee2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 15 deletions

View file

@ -178,6 +178,7 @@ static Status AddBinaryOperator(int32_t op_type,
ModelBuilder& model_builder,
const std::string& input1,
const std::string& input2,
bool add_activation,
int32_t fuse_code,
const std::string& output,
bool output_is_nhwc,
@ -187,6 +188,7 @@ static Status AddBinaryOperator(int32_t op_type,
ModelBuilder& model_builder,
const std::string& input1,
const std::string& input2,
bool add_activation,
int32_t fuse_code,
const std::string& output,
bool output_is_nhwc,
@ -199,7 +201,11 @@ static Status AddBinaryOperator(int32_t op_type,
std::vector<uint32_t> input_indices;
input_indices.push_back(operand_indices.at(input1)); // input 1
input_indices.push_back(operand_indices.at(input2)); // input 2
ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code);
if (add_activation) {
ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code);
}
ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output));
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output],
output_scale, output_zero_point);
@ -752,6 +758,7 @@ void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N
"Mul",
"Div",
"QLinearAdd",
"Pow",
});
}
@ -760,16 +767,20 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
const auto input_defs(node.InputDefs());
int32_t op_code;
bool add_activation = true;
bool op_is_qlinear = op_type == "QLinearAdd";
if (op_type == "Add" || op_is_qlinear)
if (op_type == "Add" || op_is_qlinear) {
op_code = ANEURALNETWORKS_ADD;
else if (op_type == "Sub")
} else if (op_type == "Sub") {
op_code = ANEURALNETWORKS_SUB;
else if (op_type == "Mul")
} else if (op_type == "Mul") {
op_code = ANEURALNETWORKS_MUL;
else if (op_type == "Div")
} else if (op_type == "Div") {
op_code = ANEURALNETWORKS_DIV;
else {
} else if (op_type == "Pow") {
add_activation = false; // ANEURALNETWORKS_POW does not have activation
op_code = ANEURALNETWORKS_POW;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder, unknown op: ", op_type);
}
@ -805,9 +816,14 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input2, b_scale, b_zero_point));
}
int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]);
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
if (add_activation) {
fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]);
}
return AddBinaryOperator(op_code, model_builder,
input1, input2, fuse_code,
input1, input2,
add_activation, fuse_code,
output, output_is_nhwc, y_scale, y_zero_point);
}
@ -1116,7 +1132,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_MUL,
model_builder,
input, tensor_a_name,
ANEURALNETWORKS_FUSED_NONE,
true /* add_activation */, ANEURALNETWORKS_FUSED_NONE,
tensor_imm_product_name,
output_is_nhwc));
@ -1125,7 +1141,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_ADD,
model_builder,
tensor_imm_product_name, tensor_b_name,
fuse_code,
true /* add_activation */, fuse_code,
output,
output_is_nhwc));
@ -2411,6 +2427,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
NNAPI_EP_ADD_SHARED_OP_BUILDER("Mul", BinaryOpBuilder);
NNAPI_EP_ADD_SHARED_OP_BUILDER("Div", BinaryOpBuilder);
NNAPI_EP_ADD_SHARED_OP_BUILDER("QLinearAdd", BinaryOpBuilder);
NNAPI_EP_ADD_SHARED_OP_BUILDER("Pow", BinaryOpBuilder);
}
NNAPI_EP_ADD_SINGLE_OP_BUILDER("Relu", ReluOpBuilder);

View file

@ -194,6 +194,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
"Mul",
"Div",
"QLinearAdd",
"Pow",
});
}
@ -203,13 +204,18 @@ int32_t BinaryOpSupportChecker::GetMinSupportedSdkVer(
if (op == "Sub" || op == "Div") {
return 28;
}
if (op == "Pow") {
return 29;
}
return 27;
}
int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
const auto& op(node.OpType());
// Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now
// Add/Sub/Mul/Div/Pow opset 6- has broadcast attributes we do not support now
if (op != "QLinearAdd")
return 7;
@ -217,12 +223,37 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
}
bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearAdd")
bool is_qlinear_add = node.OpType() == "QLinearAdd";
bool is_pow = node.OpType() == "Pow";
if (!is_qlinear_add && !is_pow)
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
// QLinearAdd
if (!HasValidBinaryOpQuantizedInputs(node))
return false;
if (is_qlinear_add) {
// QLinearAdd
if (!HasValidBinaryOpQuantizedInputs(node))
return false;
}
// Pow we only support both input as fp32 now
if (is_pow) {
const auto& input1 = *node.InputDefs()[0];
const auto& input2 = *node.InputDefs()[1];
int32_t input_type_1;
if (!GetType(input1, input_type_1))
return false;
int32_t input_type_2;
if (!GetType(input2, input_type_2))
return false;
if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) {
LOGS_DEFAULT(VERBOSE) << "Pow only supports fp32 inputs, actual input type"
<< ", Input type 1: " << input_type_1
<< ", Input type 2: " << input_type_2;
return false;
}
}
return true;
}
@ -1362,6 +1393,7 @@ static OpSupportCheckerRegistrations CreateOpSupportCheckerRegistrations() {
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Mul", BinaryOpSupportChecker);
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Div", BinaryOpSupportChecker);
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("QLinearAdd", BinaryOpSupportChecker);
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Pow", BinaryOpSupportChecker);
}
// Relu is always supported, we use BaseOpSupportChecker as default