diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 99684a0264..b8072c462b 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -57,10 +57,10 @@ std::vector> ModelBuilder::GetSupportedNodes() { std::vector> supported_node_vecs; int32_t android_sdk_ver = GetAndroidSdkVer(); #ifdef __ANDROID__ - if (android_sdk_ver < 27) { - LOGS_DEFAULT(VERBOSE) << "Android API level " - << android_sdk_ver - << " is lower than 27"; + if (android_sdk_ver < ORT_NNAPI_MIN_API_LEVEL) { + LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because Android API level [" << android_sdk_ver + << "] is lower than minimal supported API level [" << ORT_NNAPI_MIN_API_LEVEL + << "] of this build for NNAPI"; return supported_node_vecs; } #endif diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index 9f67fafb88..3046cb0a26 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -10,6 +10,11 @@ #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h" #include "shaper.h" +// This is the minimal Android API Level required by ORT NNAPI EP to run +#ifndef ORT_NNAPI_MIN_API_LEVEL +#define ORT_NNAPI_MIN_API_LEVEL 27 +#endif + namespace onnxruntime { namespace nnapi { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index c56a50afa6..4646053c60 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -681,6 +681,11 @@ class BaseOpBuilder : public IOpBuilder { virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) ORT_MUST_USE_RESULT = 0; bool HasExternalInitializer(ModelBuilder& model_builder, const Node& node); + + virtual int GetMinSupportedOpSet(const Node& /* node */) { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) { return 13; } + + bool HasSupportedOpSet(const Node& node); }; bool BaseOpBuilder::IsOpSupported(ModelBuilder& model_builder, const Node& node) { @@ -702,6 +707,9 @@ bool BaseOpBuilder::IsOpSupported(ModelBuilder& model_builder, const Node& node) if (HasExternalInitializer(model_builder, node)) return false; + if (!HasSupportedOpSet(node)) + return false; + return IsOpSupportedImpl(model_builder, node); } @@ -762,6 +770,18 @@ bool BaseOpBuilder::HasExternalInitializer(ModelBuilder& model_builder, const No return false; } +bool BaseOpBuilder::HasSupportedOpSet(const Node& node) { + auto since_version = node.SinceVersion(); + if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { + LOGS_DEFAULT(VERBOSE) << node.OpType() << "is only supported for opset [" + << GetMinSupportedOpSet(node) << ", " + << GetMaxSupportedOpSet(node) << "]"; + return false; + } + + return true; +} + #pragma endregion op_base #pragma region op_binary @@ -775,6 +795,7 @@ class BinaryOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override; bool HasSupportedInputs(const Node& node) override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + int GetMinSupportedOpSet(const Node& node) override; }; void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) { @@ -793,6 +814,16 @@ int32_t BinaryOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */ return 27; } +int BinaryOpBuilder::GetMinSupportedOpSet(const Node& node) { + const auto& op(node.OpType()); + + // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now + if (op != "QLinearAdd") + return 7; + + return 1; +} + bool BinaryOpBuilder::HasSupportedInputs(const Node& node) { if (node.OpType() != "QLinearAdd") return BaseOpBuilder::HasSupportedInputs(node); @@ -1031,6 +1062,9 @@ class ReshapeOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank); + + // Reshape opset 4- uses attributes for new shape which we do not support for now + int GetMinSupportedOpSet(const Node& /* node */) override { return 5; } }; void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) { @@ -1201,6 +1235,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + + // BatchNormalization opset 6- has unsupported attributes + int GetMinSupportedOpSet(const Node& /* node */) override { return 7; } }; void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) { @@ -1849,6 +1886,9 @@ class CastOpBuilder : public BaseOpBuilder { } Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + + // Cast opset 5- uses string attribute for to type, is not supported for now + int GetMinSupportedOpSet(const Node& /* node */) override { return 6; } }; bool CastOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) { @@ -2019,6 +2059,7 @@ class GemmOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override; bool HasSupportedInputs(const Node& node) override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + int GetMinSupportedOpSet(const Node& node) override; }; bool GemmOpBuilder::HasSupportedInputs(const Node& node) { @@ -2032,6 +2073,16 @@ bool GemmOpBuilder::HasSupportedInputs(const Node& node) { return true; } +int GemmOpBuilder::GetMinSupportedOpSet(const Node& node) { + const auto& op(node.OpType()); + + // Gemm opset 6- has broadcast attributes we do not support now + if (op == "Gemm") + return 7; + + return 1; +} + bool GemmOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) { const auto& op_type = node.OpType(); const auto input_defs(node.InputDefs()); @@ -2257,6 +2308,10 @@ class UnaryOpBuilder : public BaseOpBuilder { int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + + // All ops except "Sin" opset 5- uses consumed_inputs attribute which is not supported for now + // "Sin" op has support from opset 7, return 6 here for all ops + int GetMinSupportedOpSet(const Node& /* node */) override { return 6; } }; int32_t UnaryOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& node) const { @@ -2418,6 +2473,10 @@ class SqueezeOpBuilder : public BaseOpBuilder { } Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + + // Squeeze opset 13+ uses input for axes, which is not supported yet + // TODO add support for squeeze opset 13+ + int GetMaxSupportedOpSet(const Node& /* node */) override { return 12; } }; bool SqueezeOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) { @@ -2811,6 +2870,10 @@ class ResizeOpBuilder : public BaseOpBuilder { int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT; + + // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing + // We only support Resize opset 11+ here + int GetMinSupportedOpSet(const Node& /* node */) override { return 11; } }; int32_t ResizeOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& /* node */) const { @@ -2827,13 +2890,6 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N } bool ResizeOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) { - // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing - // We only support Resize opset 11+ here - if (node.SinceVersion() < 11) { - LOGS_DEFAULT(VERBOSE) << "Resize only supports opset 11+"; - return false; - } - Shape input_shape; if (!GetShape(*node.InputDefs()[0], input_shape)) return false;