mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
[NNAPI] Add pow support (#6310)
This commit is contained in:
parent
fcd9fc9b6d
commit
b220feee2f
2 changed files with 64 additions and 15 deletions
|
|
@ -178,6 +178,7 @@ static Status AddBinaryOperator(int32_t op_type,
|
|||
ModelBuilder& model_builder,
|
||||
const std::string& input1,
|
||||
const std::string& input2,
|
||||
bool add_activation,
|
||||
int32_t fuse_code,
|
||||
const std::string& output,
|
||||
bool output_is_nhwc,
|
||||
|
|
@ -187,6 +188,7 @@ static Status AddBinaryOperator(int32_t op_type,
|
|||
ModelBuilder& model_builder,
|
||||
const std::string& input1,
|
||||
const std::string& input2,
|
||||
bool add_activation,
|
||||
int32_t fuse_code,
|
||||
const std::string& output,
|
||||
bool output_is_nhwc,
|
||||
|
|
@ -199,7 +201,11 @@ static Status AddBinaryOperator(int32_t op_type,
|
|||
std::vector<uint32_t> input_indices;
|
||||
input_indices.push_back(operand_indices.at(input1)); // input 1
|
||||
input_indices.push_back(operand_indices.at(input2)); // input 2
|
||||
ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code);
|
||||
|
||||
if (add_activation) {
|
||||
ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(shaper.Eltwise(input1, input2, output));
|
||||
const OperandType output_operand_type(operand_types.at(input1).type, shaper[output],
|
||||
output_scale, output_zero_point);
|
||||
|
|
@ -752,6 +758,7 @@ void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N
|
|||
"Mul",
|
||||
"Div",
|
||||
"QLinearAdd",
|
||||
"Pow",
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -760,16 +767,20 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
const auto input_defs(node.InputDefs());
|
||||
|
||||
int32_t op_code;
|
||||
bool add_activation = true;
|
||||
bool op_is_qlinear = op_type == "QLinearAdd";
|
||||
if (op_type == "Add" || op_is_qlinear)
|
||||
if (op_type == "Add" || op_is_qlinear) {
|
||||
op_code = ANEURALNETWORKS_ADD;
|
||||
else if (op_type == "Sub")
|
||||
} else if (op_type == "Sub") {
|
||||
op_code = ANEURALNETWORKS_SUB;
|
||||
else if (op_type == "Mul")
|
||||
} else if (op_type == "Mul") {
|
||||
op_code = ANEURALNETWORKS_MUL;
|
||||
else if (op_type == "Div")
|
||||
} else if (op_type == "Div") {
|
||||
op_code = ANEURALNETWORKS_DIV;
|
||||
else {
|
||||
} else if (op_type == "Pow") {
|
||||
add_activation = false; // ANEURALNETWORKS_POW does not have activation
|
||||
op_code = ANEURALNETWORKS_POW;
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder, unknown op: ", op_type);
|
||||
}
|
||||
|
||||
|
|
@ -805,9 +816,14 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
|
|||
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input2, b_scale, b_zero_point));
|
||||
}
|
||||
|
||||
int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]);
|
||||
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
|
||||
if (add_activation) {
|
||||
fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]);
|
||||
}
|
||||
|
||||
return AddBinaryOperator(op_code, model_builder,
|
||||
input1, input2, fuse_code,
|
||||
input1, input2,
|
||||
add_activation, fuse_code,
|
||||
output, output_is_nhwc, y_scale, y_zero_point);
|
||||
}
|
||||
|
||||
|
|
@ -1116,7 +1132,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
|
|||
ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_MUL,
|
||||
model_builder,
|
||||
input, tensor_a_name,
|
||||
ANEURALNETWORKS_FUSED_NONE,
|
||||
true /* add_activation */, ANEURALNETWORKS_FUSED_NONE,
|
||||
tensor_imm_product_name,
|
||||
output_is_nhwc));
|
||||
|
||||
|
|
@ -1125,7 +1141,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
|
|||
ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_ADD,
|
||||
model_builder,
|
||||
tensor_imm_product_name, tensor_b_name,
|
||||
fuse_code,
|
||||
true /* add_activation */, fuse_code,
|
||||
output,
|
||||
output_is_nhwc));
|
||||
|
||||
|
|
@ -2411,6 +2427,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
NNAPI_EP_ADD_SHARED_OP_BUILDER("Mul", BinaryOpBuilder);
|
||||
NNAPI_EP_ADD_SHARED_OP_BUILDER("Div", BinaryOpBuilder);
|
||||
NNAPI_EP_ADD_SHARED_OP_BUILDER("QLinearAdd", BinaryOpBuilder);
|
||||
NNAPI_EP_ADD_SHARED_OP_BUILDER("Pow", BinaryOpBuilder);
|
||||
}
|
||||
|
||||
NNAPI_EP_ADD_SINGLE_OP_BUILDER("Relu", ReluOpBuilder);
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
|
|||
"Mul",
|
||||
"Div",
|
||||
"QLinearAdd",
|
||||
"Pow",
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -203,13 +204,18 @@ int32_t BinaryOpSupportChecker::GetMinSupportedSdkVer(
|
|||
if (op == "Sub" || op == "Div") {
|
||||
return 28;
|
||||
}
|
||||
|
||||
if (op == "Pow") {
|
||||
return 29;
|
||||
}
|
||||
|
||||
return 27;
|
||||
}
|
||||
|
||||
int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
|
||||
const auto& op(node.OpType());
|
||||
|
||||
// Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now
|
||||
// Add/Sub/Mul/Div/Pow opset 6- has broadcast attributes we do not support now
|
||||
if (op != "QLinearAdd")
|
||||
return 7;
|
||||
|
||||
|
|
@ -217,12 +223,37 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
|
|||
}
|
||||
|
||||
bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
if (node.OpType() != "QLinearAdd")
|
||||
bool is_qlinear_add = node.OpType() == "QLinearAdd";
|
||||
bool is_pow = node.OpType() == "Pow";
|
||||
if (!is_qlinear_add && !is_pow)
|
||||
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
|
||||
|
||||
// QLinearAdd
|
||||
if (!HasValidBinaryOpQuantizedInputs(node))
|
||||
return false;
|
||||
if (is_qlinear_add) {
|
||||
// QLinearAdd
|
||||
if (!HasValidBinaryOpQuantizedInputs(node))
|
||||
return false;
|
||||
}
|
||||
|
||||
// Pow we only support both input as fp32 now
|
||||
if (is_pow) {
|
||||
const auto& input1 = *node.InputDefs()[0];
|
||||
const auto& input2 = *node.InputDefs()[1];
|
||||
|
||||
int32_t input_type_1;
|
||||
if (!GetType(input1, input_type_1))
|
||||
return false;
|
||||
|
||||
int32_t input_type_2;
|
||||
if (!GetType(input2, input_type_2))
|
||||
return false;
|
||||
|
||||
if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Pow only supports fp32 inputs, actual input type"
|
||||
<< ", Input type 1: " << input_type_1
|
||||
<< ", Input type 2: " << input_type_2;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1362,6 +1393,7 @@ static OpSupportCheckerRegistrations CreateOpSupportCheckerRegistrations() {
|
|||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Mul", BinaryOpSupportChecker);
|
||||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Div", BinaryOpSupportChecker);
|
||||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("QLinearAdd", BinaryOpSupportChecker);
|
||||
NNAPI_EP_ADD_SHARED_OP_SUPPORT_CHECKER("Pow", BinaryOpSupportChecker);
|
||||
}
|
||||
|
||||
// Relu is always supported, we use BaseOpSupportChecker as default
|
||||
|
|
|
|||
Loading…
Reference in a new issue