NNAPI add opset version check (#5687)

* nnapi add opset support
This commit is contained in:
Guoyu Wang 2020-11-04 03:48:00 -08:00 committed by GitHub
parent 07bd4ef470
commit 2ad7bcb766
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 11 deletions

View file

@ -57,10 +57,10 @@ std::vector<std::vector<int>> ModelBuilder::GetSupportedNodes() {
std::vector<std::vector<int>> supported_node_vecs;
int32_t android_sdk_ver = GetAndroidSdkVer();
#ifdef __ANDROID__
if (android_sdk_ver < 27) {
LOGS_DEFAULT(VERBOSE) << "Android API level "
<< android_sdk_ver
<< " is lower than 27";
if (android_sdk_ver < ORT_NNAPI_MIN_API_LEVEL) {
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because Android API level [" << android_sdk_ver
<< "] is lower than minimal supported API level [" << ORT_NNAPI_MIN_API_LEVEL
<< "] of this build for NNAPI";
return supported_node_vecs;
}
#endif

View file

@ -10,6 +10,11 @@
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h"
#include "shaper.h"
// This is the minimal Android API Level required by ORT NNAPI EP to run
#ifndef ORT_NNAPI_MIN_API_LEVEL
#define ORT_NNAPI_MIN_API_LEVEL 27
#endif
namespace onnxruntime {
namespace nnapi {

View file

@ -681,6 +681,11 @@ class BaseOpBuilder : public IOpBuilder {
virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) ORT_MUST_USE_RESULT = 0;
bool HasExternalInitializer(ModelBuilder& model_builder, const Node& node);
virtual int GetMinSupportedOpSet(const Node& /* node */) { return 1; }
virtual int GetMaxSupportedOpSet(const Node& /* node */) { return 13; }
bool HasSupportedOpSet(const Node& node);
};
bool BaseOpBuilder::IsOpSupported(ModelBuilder& model_builder, const Node& node) {
@ -702,6 +707,9 @@ bool BaseOpBuilder::IsOpSupported(ModelBuilder& model_builder, const Node& node)
if (HasExternalInitializer(model_builder, node))
return false;
if (!HasSupportedOpSet(node))
return false;
return IsOpSupportedImpl(model_builder, node);
}
@ -762,6 +770,18 @@ bool BaseOpBuilder::HasExternalInitializer(ModelBuilder& model_builder, const No
return false;
}
bool BaseOpBuilder::HasSupportedOpSet(const Node& node) {
auto since_version = node.SinceVersion();
if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) {
LOGS_DEFAULT(VERBOSE) << node.OpType() << "is only supported for opset ["
<< GetMinSupportedOpSet(node) << ", "
<< GetMaxSupportedOpSet(node) << "]";
return false;
}
return true;
}
#pragma endregion op_base
#pragma region op_binary
@ -775,6 +795,7 @@ class BinaryOpBuilder : public BaseOpBuilder {
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
bool HasSupportedInputs(const Node& node) override;
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
int GetMinSupportedOpSet(const Node& node) override;
};
void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) {
@ -793,6 +814,16 @@ int32_t BinaryOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */
return 27;
}
int BinaryOpBuilder::GetMinSupportedOpSet(const Node& node) {
const auto& op(node.OpType());
// Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now
if (op != "QLinearAdd")
return 7;
return 1;
}
bool BinaryOpBuilder::HasSupportedInputs(const Node& node) {
if (node.OpType() != "QLinearAdd")
return BaseOpBuilder::HasSupportedInputs(node);
@ -1031,6 +1062,9 @@ class ReshapeOpBuilder : public BaseOpBuilder {
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank);
// Reshape opset 4- uses attributes for new shape which we do not support for now
int GetMinSupportedOpSet(const Node& /* node */) override { return 5; }
};
void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) {
@ -1201,6 +1235,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder {
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
// BatchNormalization opset 6- has unsupported attributes
int GetMinSupportedOpSet(const Node& /* node */) override { return 7; }
};
void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) {
@ -1849,6 +1886,9 @@ class CastOpBuilder : public BaseOpBuilder {
}
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
// Cast opset 5- uses string attribute for to type, is not supported for now
int GetMinSupportedOpSet(const Node& /* node */) override { return 6; }
};
bool CastOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
@ -2019,6 +2059,7 @@ class GemmOpBuilder : public BaseOpBuilder {
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
bool HasSupportedInputs(const Node& node) override;
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
int GetMinSupportedOpSet(const Node& node) override;
};
bool GemmOpBuilder::HasSupportedInputs(const Node& node) {
@ -2032,6 +2073,16 @@ bool GemmOpBuilder::HasSupportedInputs(const Node& node) {
return true;
}
int GemmOpBuilder::GetMinSupportedOpSet(const Node& node) {
const auto& op(node.OpType());
// Gemm opset 6- has broadcast attributes we do not support now
if (op == "Gemm")
return 7;
return 1;
}
bool GemmOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) {
const auto& op_type = node.OpType();
const auto input_defs(node.InputDefs());
@ -2257,6 +2308,10 @@ class UnaryOpBuilder : public BaseOpBuilder {
int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& node) const override;
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
// All ops except "Sin" opset 5- uses consumed_inputs attribute which is not supported for now
// "Sin" op has support from opset 7, return 6 here for all ops
int GetMinSupportedOpSet(const Node& /* node */) override { return 6; }
};
int32_t UnaryOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& node) const {
@ -2418,6 +2473,10 @@ class SqueezeOpBuilder : public BaseOpBuilder {
}
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
// Squeeze opset 13+ uses input for axes, which is not supported yet
// TODO add support for squeeze opset 13+
int GetMaxSupportedOpSet(const Node& /* node */) override { return 12; }
};
bool SqueezeOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
@ -2811,6 +2870,10 @@ class ResizeOpBuilder : public BaseOpBuilder {
int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& node) const override;
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing
// We only support Resize opset 11+ here
int GetMinSupportedOpSet(const Node& /* node */) override { return 11; }
};
int32_t ResizeOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& /* node */) const {
@ -2827,13 +2890,6 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N
}
bool ResizeOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) {
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing
// We only support Resize opset 11+ here
if (node.SinceVersion() < 11) {
LOGS_DEFAULT(VERBOSE) << "Resize only supports opset 11+";
return false;
}
Shape input_shape;
if (!GetShape(*node.InputDefs()[0], input_shape))
return false;