mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Make NNAPI EP reject nodes with no-shape inputs (#5927)
This commit is contained in:
parent
fddbd8935c
commit
87368655e2
2 changed files with 75 additions and 18 deletions
|
|
@ -80,11 +80,14 @@ class BaseOpSupportChecker : public IOpSupportChecker {
|
|||
return 27;
|
||||
}
|
||||
|
||||
virtual bool HasSupportedInputs(const Node& node) const;
|
||||
virtual bool HasSupportedInputsImpl(const Node& node) const;
|
||||
|
||||
virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
|
||||
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; }
|
||||
|
||||
private:
|
||||
bool HasSupportedOpSet(const Node& node) const;
|
||||
bool HasSupportedInputs(const Node& node) const;
|
||||
};
|
||||
|
||||
/* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker(
|
||||
|
|
@ -121,16 +124,23 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer
|
|||
}
|
||||
|
||||
bool BaseOpSupportChecker::HasSupportedInputs(const Node& node) const {
|
||||
// We do not support unknown(null) input shape
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
if (!input->Shape()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
|
||||
<< "] Input [" << input->Name() << "] has no shape";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return HasSupportedInputsImpl(node);
|
||||
}
|
||||
|
||||
bool BaseOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
// We only check the type of input 0 by default
|
||||
// specific op builder can override this
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
|
||||
if (nullptr == input.Shape()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input shape is null";
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type))
|
||||
return false;
|
||||
|
|
@ -170,7 +180,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
|
|||
int32_t GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& params) const override;
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const override;
|
||||
bool HasSupportedInputs(const Node& node) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
int GetMinSupportedOpSet(const Node& node) const override;
|
||||
};
|
||||
|
||||
|
|
@ -206,9 +216,9 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
|
|||
return 1;
|
||||
}
|
||||
|
||||
bool BinaryOpSupportChecker::HasSupportedInputs(const Node& node) const {
|
||||
bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
if (node.OpType() != "QLinearAdd")
|
||||
return BaseOpSupportChecker::HasSupportedInputs(node);
|
||||
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
|
||||
|
||||
// QLinearAdd
|
||||
if (!HasValidBinaryOpQuantizedInputs(node))
|
||||
|
|
@ -511,7 +521,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
|
|||
return params.use_nchw ? 29 : 28;
|
||||
}
|
||||
|
||||
bool HasSupportedInputs(const Node& node) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
};
|
||||
|
||||
/* static */ void ConvOpSupportChecker::CreateSharedOpSupportChecker(
|
||||
|
|
@ -524,9 +534,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
|
|||
});
|
||||
}
|
||||
|
||||
bool ConvOpSupportChecker::HasSupportedInputs(const Node& node) const {
|
||||
bool ConvOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
if (node.OpType() != "QLinearConv")
|
||||
return BaseOpSupportChecker::HasSupportedInputs(node);
|
||||
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
|
||||
|
||||
// QLinearConv only supports input of uint8 for now
|
||||
if (!HasValidBinaryOpQuantizedInputs(node))
|
||||
|
|
@ -683,13 +693,13 @@ class GemmOpSupportChecker : public BaseOpSupportChecker {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const OpSupportCheckParams& params) const override;
|
||||
bool HasSupportedInputs(const Node& node) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
int GetMinSupportedOpSet(const Node& node) const override;
|
||||
};
|
||||
|
||||
bool GemmOpSupportChecker::HasSupportedInputs(const Node& node) const {
|
||||
bool GemmOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
if (node.OpType() != "QLinearMatMul")
|
||||
return BaseOpSupportChecker::HasSupportedInputs(node);
|
||||
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
|
||||
|
||||
// QLinearMatMul
|
||||
if (!HasValidBinaryOpQuantizedInputs(node))
|
||||
|
|
@ -990,7 +1000,7 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker {
|
|||
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
|
||||
return 29;
|
||||
}
|
||||
bool HasSupportedInputs(const Node& node) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node) const override;
|
||||
};
|
||||
|
||||
bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
|
|
@ -1007,7 +1017,7 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor
|
|||
return true;
|
||||
}
|
||||
|
||||
bool DequantizeLinearOpSupportChecker::HasSupportedInputs(const Node& node) const {
|
||||
bool DequantizeLinearOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
|
||||
int32_t input_type;
|
||||
if (!GetType(*node.InputDefs()[0], input_type))
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -155,6 +155,53 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
|
|||
<< "Some nodes should have been taken by the NNAPI EP";
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(NnapiExecutionProviderTest, TestNoShapeInputModel) {
|
||||
const ORTCHAR_T* model_file_name = ORT_TSTR("input_with_no_shape_test_graph.onnx");
|
||||
|
||||
{ // Create the model with 2 add nodes, the graph has 2 inputs with no shape
|
||||
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
||||
// FLOAT tensor without shape
|
||||
ONNX_NAMESPACE::TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
|
||||
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
|
||||
inputs.push_back(&input_arg_1);
|
||||
inputs.push_back(&input_arg_2);
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
|
||||
outputs.push_back(&output_arg);
|
||||
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
|
||||
|
||||
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
|
||||
inputs.clear();
|
||||
inputs.push_back(&output_arg);
|
||||
inputs.push_back(&input_arg_3);
|
||||
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
|
||||
outputs.clear();
|
||||
outputs.push_back(&output_arg_2);
|
||||
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
|
||||
|
||||
ASSERT_STATUS_OK(graph.Resolve());
|
||||
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
|
||||
}
|
||||
|
||||
// test load only
|
||||
// since we know NNAPI supports Add op, but both Add ops in the graph has no input shape
|
||||
// verify the entire graph will not be assigned to NNAPI EP
|
||||
SessionOptions so;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_EQ(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
|
||||
<< "No node should be taken by the NNAPI EP";
|
||||
}
|
||||
|
||||
#endif // !(ORT_MINIMAL_BUILD
|
||||
|
||||
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue