Make NNAPI EP reject nodes with no-shape inputs (#5927)

This commit is contained in:
Guoyu Wang 2020-11-25 00:21:00 -08:00 committed by GitHub
parent fddbd8935c
commit 87368655e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 18 deletions

View file

@ -80,11 +80,14 @@ class BaseOpSupportChecker : public IOpSupportChecker {
return 27;
}
virtual bool HasSupportedInputs(const Node& node) const;
virtual bool HasSupportedInputsImpl(const Node& node) const;
virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; }
private:
bool HasSupportedOpSet(const Node& node) const;
bool HasSupportedInputs(const Node& node) const;
};
/* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker(
@ -121,16 +124,23 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer
}
bool BaseOpSupportChecker::HasSupportedInputs(const Node& node) const {
// We do not support unknown(null) input shape
for (const auto* input : node.InputDefs()) {
if (!input->Shape()) {
LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
<< "] Input [" << input->Name() << "] has no shape";
return false;
}
}
return HasSupportedInputsImpl(node);
}
bool BaseOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
// We only check the type of input 0 by default
// specific op builder can override this
const auto& input = *node.InputDefs()[0];
if (nullptr == input.Shape()) {
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
<< "] Input shape is null";
return false;
}
int32_t input_type;
if (!GetType(input, input_type))
return false;
@ -170,7 +180,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
int32_t GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& params) const override;
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const OpSupportCheckParams& params) const override;
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
int GetMinSupportedOpSet(const Node& node) const override;
};
@ -206,9 +216,9 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
return 1;
}
bool BinaryOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearAdd")
return BaseOpSupportChecker::HasSupportedInputs(node);
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
// QLinearAdd
if (!HasValidBinaryOpQuantizedInputs(node))
@ -511,7 +521,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
return params.use_nchw ? 29 : 28;
}
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
};
/* static */ void ConvOpSupportChecker::CreateSharedOpSupportChecker(
@ -524,9 +534,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
});
}
bool ConvOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool ConvOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearConv")
return BaseOpSupportChecker::HasSupportedInputs(node);
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
// QLinearConv only supports input of uint8 for now
if (!HasValidBinaryOpQuantizedInputs(node))
@ -683,13 +693,13 @@ class GemmOpSupportChecker : public BaseOpSupportChecker {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const OpSupportCheckParams& params) const override;
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
int GetMinSupportedOpSet(const Node& node) const override;
};
bool GemmOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool GemmOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearMatMul")
return BaseOpSupportChecker::HasSupportedInputs(node);
return BaseOpSupportChecker::HasSupportedInputsImpl(node);
// QLinearMatMul
if (!HasValidBinaryOpQuantizedInputs(node))
@ -990,7 +1000,7 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker {
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
return 29;
}
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
};
bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
@ -1007,7 +1017,7 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor
return true;
}
bool DequantizeLinearOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool DequantizeLinearOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
int32_t input_type;
if (!GetType(*node.InputDefs()[0], input_type))
return false;

View file

@ -155,6 +155,53 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
<< "Some nodes should have been taken by the NNAPI EP";
#endif
}
TEST(NnapiExecutionProviderTest, TestNoShapeInputModel) {
const ORTCHAR_T* model_file_name = ORT_TSTR("input_with_no_shape_test_graph.onnx");
{ // Create the model with 2 add nodes, the graph has 2 inputs with no shape
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
// FLOAT tensor without shape
ONNX_NAMESPACE::TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
inputs.push_back(&input_arg_1);
inputs.push_back(&input_arg_2);
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
outputs.push_back(&output_arg);
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
inputs.clear();
inputs.push_back(&output_arg);
inputs.push_back(&input_arg_3);
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
outputs.clear();
outputs.push_back(&output_arg_2);
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
ASSERT_STATUS_OK(graph.Resolve());
ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
}
// test load only
// since we know NNAPI supports Add op, but both Add ops in the graph has no input shape
// verify the entire graph will not be assigned to NNAPI EP
SessionOptions so;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
ASSERT_EQ(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
<< "No node should be taken by the NNAPI EP";
}
#endif // !(ORT_MINIMAL_BUILD
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {