mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
parent
07bd4ef470
commit
2ad7bcb766
3 changed files with 72 additions and 11 deletions
|
|
@ -57,10 +57,10 @@ std::vector<std::vector<int>> ModelBuilder::GetSupportedNodes() {
|
|||
std::vector<std::vector<int>> supported_node_vecs;
|
||||
int32_t android_sdk_ver = GetAndroidSdkVer();
|
||||
#ifdef __ANDROID__
|
||||
if (android_sdk_ver < 27) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Android API level "
|
||||
<< android_sdk_ver
|
||||
<< " is lower than 27";
|
||||
if (android_sdk_ver < ORT_NNAPI_MIN_API_LEVEL) {
|
||||
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because Android API level [" << android_sdk_ver
|
||||
<< "] is lower than minimal supported API level [" << ORT_NNAPI_MIN_API_LEVEL
|
||||
<< "] of this build for NNAPI";
|
||||
return supported_node_vecs;
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@
|
|||
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h"
|
||||
#include "shaper.h"
|
||||
|
||||
// This is the minimal Android API Level required by ORT NNAPI EP to run
|
||||
#ifndef ORT_NNAPI_MIN_API_LEVEL
|
||||
#define ORT_NNAPI_MIN_API_LEVEL 27
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace nnapi {
|
||||
|
||||
|
|
|
|||
|
|
@ -681,6 +681,11 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) ORT_MUST_USE_RESULT = 0;
|
||||
|
||||
bool HasExternalInitializer(ModelBuilder& model_builder, const Node& node);
|
||||
|
||||
virtual int GetMinSupportedOpSet(const Node& /* node */) { return 1; }
|
||||
virtual int GetMaxSupportedOpSet(const Node& /* node */) { return 13; }
|
||||
|
||||
bool HasSupportedOpSet(const Node& node);
|
||||
};
|
||||
|
||||
bool BaseOpBuilder::IsOpSupported(ModelBuilder& model_builder, const Node& node) {
|
||||
|
|
@ -702,6 +707,9 @@ bool BaseOpBuilder::IsOpSupported(ModelBuilder& model_builder, const Node& node)
|
|||
if (HasExternalInitializer(model_builder, node))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOpSet(node))
|
||||
return false;
|
||||
|
||||
return IsOpSupportedImpl(model_builder, node);
|
||||
}
|
||||
|
||||
|
|
@ -762,6 +770,18 @@ bool BaseOpBuilder::HasExternalInitializer(ModelBuilder& model_builder, const No
|
|||
return false;
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedOpSet(const Node& node) {
|
||||
auto since_version = node.SinceVersion();
|
||||
if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) {
|
||||
LOGS_DEFAULT(VERBOSE) << node.OpType() << "is only supported for opset ["
|
||||
<< GetMinSupportedOpSet(node) << ", "
|
||||
<< GetMaxSupportedOpSet(node) << "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#pragma endregion op_base
|
||||
|
||||
#pragma region op_binary
|
||||
|
|
@ -775,6 +795,7 @@ class BinaryOpBuilder : public BaseOpBuilder {
|
|||
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
|
||||
bool HasSupportedInputs(const Node& node) override;
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
int GetMinSupportedOpSet(const Node& node) override;
|
||||
};
|
||||
|
||||
void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) {
|
||||
|
|
@ -793,6 +814,16 @@ int32_t BinaryOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */
|
|||
return 27;
|
||||
}
|
||||
|
||||
int BinaryOpBuilder::GetMinSupportedOpSet(const Node& node) {
|
||||
const auto& op(node.OpType());
|
||||
|
||||
// Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now
|
||||
if (op != "QLinearAdd")
|
||||
return 7;
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool BinaryOpBuilder::HasSupportedInputs(const Node& node) {
|
||||
if (node.OpType() != "QLinearAdd")
|
||||
return BaseOpBuilder::HasSupportedInputs(node);
|
||||
|
|
@ -1031,6 +1062,9 @@ class ReshapeOpBuilder : public BaseOpBuilder {
|
|||
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
static bool CanSkipReshape(const Node& node, size_t input_rank, size_t output_rank);
|
||||
|
||||
// Reshape opset 4- uses attributes for new shape which we do not support for now
|
||||
int GetMinSupportedOpSet(const Node& /* node */) override { return 5; }
|
||||
};
|
||||
|
||||
void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) {
|
||||
|
|
@ -1201,6 +1235,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder {
|
|||
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
|
||||
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
|
||||
// BatchNormalization opset 6- has unsupported attributes
|
||||
int GetMinSupportedOpSet(const Node& /* node */) override { return 7; }
|
||||
};
|
||||
|
||||
void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) {
|
||||
|
|
@ -1849,6 +1886,9 @@ class CastOpBuilder : public BaseOpBuilder {
|
|||
}
|
||||
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Cast opset 5- uses string attribute for to type, is not supported for now
|
||||
int GetMinSupportedOpSet(const Node& /* node */) override { return 6; }
|
||||
};
|
||||
|
||||
bool CastOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
|
||||
|
|
@ -2019,6 +2059,7 @@ class GemmOpBuilder : public BaseOpBuilder {
|
|||
bool IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) override;
|
||||
bool HasSupportedInputs(const Node& node) override;
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
int GetMinSupportedOpSet(const Node& node) override;
|
||||
};
|
||||
|
||||
bool GemmOpBuilder::HasSupportedInputs(const Node& node) {
|
||||
|
|
@ -2032,6 +2073,16 @@ bool GemmOpBuilder::HasSupportedInputs(const Node& node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
int GemmOpBuilder::GetMinSupportedOpSet(const Node& node) {
|
||||
const auto& op(node.OpType());
|
||||
|
||||
// Gemm opset 6- has broadcast attributes we do not support now
|
||||
if (op == "Gemm")
|
||||
return 7;
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool GemmOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) {
|
||||
const auto& op_type = node.OpType();
|
||||
const auto input_defs(node.InputDefs());
|
||||
|
|
@ -2257,6 +2308,10 @@ class UnaryOpBuilder : public BaseOpBuilder {
|
|||
int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& node) const override;
|
||||
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
|
||||
// All ops except "Sin" opset 5- uses consumed_inputs attribute which is not supported for now
|
||||
// "Sin" op has support from opset 7, return 6 here for all ops
|
||||
int GetMinSupportedOpSet(const Node& /* node */) override { return 6; }
|
||||
};
|
||||
|
||||
int32_t UnaryOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& node) const {
|
||||
|
|
@ -2418,6 +2473,10 @@ class SqueezeOpBuilder : public BaseOpBuilder {
|
|||
}
|
||||
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Squeeze opset 13+ uses input for axes, which is not supported yet
|
||||
// TODO add support for squeeze opset 13+
|
||||
int GetMaxSupportedOpSet(const Node& /* node */) override { return 12; }
|
||||
};
|
||||
|
||||
bool SqueezeOpBuilder::IsOpSupportedImpl(ModelBuilder& /* model_builder */, const Node& node) {
|
||||
|
|
@ -2811,6 +2870,10 @@ class ResizeOpBuilder : public BaseOpBuilder {
|
|||
int32_t GetMinSupportedSdkVer(ModelBuilder& model_builder, const Node& node) const override;
|
||||
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node) override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing
|
||||
// We only support Resize opset 11+ here
|
||||
int GetMinSupportedOpSet(const Node& /* node */) override { return 11; }
|
||||
};
|
||||
|
||||
int32_t ResizeOpBuilder::GetMinSupportedSdkVer(ModelBuilder& /* model_builder */, const Node& /* node */) const {
|
||||
|
|
@ -2827,13 +2890,6 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N
|
|||
}
|
||||
|
||||
bool ResizeOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, const Node& node) {
|
||||
// Resize opset 10- is very different than Resize opset 11+, with many key attributes missing
|
||||
// We only support Resize opset 11+ here
|
||||
if (node.SinceVersion() < 11) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Resize only supports opset 11+";
|
||||
return false;
|
||||
}
|
||||
|
||||
Shape input_shape;
|
||||
if (!GetShape(*node.InputDefs()[0], input_shape))
|
||||
return false;
|
||||
|
|
|
|||
Loading…
Reference in a new issue