diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index e465ba2b1d..5ada0499fa 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -80,11 +80,14 @@ class BaseOpSupportChecker : public IOpSupportChecker { return 27; } - virtual bool HasSupportedInputs(const Node& node) const; + virtual bool HasSupportedInputsImpl(const Node& node) const; virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; } + + private: bool HasSupportedOpSet(const Node& node) const; + bool HasSupportedInputs(const Node& node) const; }; /* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker( @@ -121,16 +124,23 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer } bool BaseOpSupportChecker::HasSupportedInputs(const Node& node) const { + // We do not support unknown(null) input shape + for (const auto* input : node.InputDefs()) { + if (!input->Shape()) { + LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType() + << "] Input [" << input->Name() << "] has no shape"; + return false; + } + } + + return HasSupportedInputsImpl(node); +} + +bool BaseOpSupportChecker::HasSupportedInputsImpl(const Node& node) const { // We only check the type of input 0 by default // specific op builder can override this const auto& input = *node.InputDefs()[0]; - if (nullptr == input.Shape()) { - LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() - << "] Input shape is null"; - return false; - } - int32_t input_type; if (!GetType(input, input_type)) return false; @@ -170,7 +180,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker { int32_t GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& params) const override; bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const OpSupportCheckParams& params) const override; - bool HasSupportedInputs(const Node& node) const override; + bool HasSupportedInputsImpl(const Node& node) const override; int GetMinSupportedOpSet(const Node& node) const override; }; @@ -206,9 +216,9 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const { return 1; } -bool BinaryOpSupportChecker::HasSupportedInputs(const Node& node) const { +bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const { if (node.OpType() != "QLinearAdd") - return BaseOpSupportChecker::HasSupportedInputs(node); + return BaseOpSupportChecker::HasSupportedInputsImpl(node); // QLinearAdd if (!HasValidBinaryOpQuantizedInputs(node)) @@ -511,7 +521,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { return params.use_nchw ? 29 : 28; } - bool HasSupportedInputs(const Node& node) const override; + bool HasSupportedInputsImpl(const Node& node) const override; }; /* static */ void ConvOpSupportChecker::CreateSharedOpSupportChecker( @@ -524,9 +534,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { }); } -bool ConvOpSupportChecker::HasSupportedInputs(const Node& node) const { +bool ConvOpSupportChecker::HasSupportedInputsImpl(const Node& node) const { if (node.OpType() != "QLinearConv") - return BaseOpSupportChecker::HasSupportedInputs(node); + return BaseOpSupportChecker::HasSupportedInputsImpl(node); // QLinearConv only supports input of uint8 for now if (!HasValidBinaryOpQuantizedInputs(node)) @@ -683,13 +693,13 @@ class GemmOpSupportChecker : public BaseOpSupportChecker { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const OpSupportCheckParams& params) const override; - bool HasSupportedInputs(const Node& node) const override; + bool HasSupportedInputsImpl(const Node& node) const override; int GetMinSupportedOpSet(const Node& node) const override; }; -bool GemmOpSupportChecker::HasSupportedInputs(const Node& node) const { +bool GemmOpSupportChecker::HasSupportedInputsImpl(const Node& node) const { if (node.OpType() != "QLinearMatMul") - return BaseOpSupportChecker::HasSupportedInputs(node); + return BaseOpSupportChecker::HasSupportedInputsImpl(node); // QLinearMatMul if (!HasValidBinaryOpQuantizedInputs(node)) @@ -990,7 +1000,7 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker { int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override { return 29; } - bool HasSupportedInputs(const Node& node) const override; + bool HasSupportedInputsImpl(const Node& node) const override; }; bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, @@ -1007,7 +1017,7 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor return true; } -bool DequantizeLinearOpSupportChecker::HasSupportedInputs(const Node& node) const { +bool DequantizeLinearOpSupportChecker::HasSupportedInputsImpl(const Node& node) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type)) return false; diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index c1cb1ffac3..1b5c5343c3 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -155,6 +155,53 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { << "Some nodes should have been taken by the NNAPI EP"; #endif } + +TEST(NnapiExecutionProviderTest, TestNoShapeInputModel) { + const ORTCHAR_T* model_file_name = ORT_TSTR("input_with_no_shape_test_graph.onnx"); + + { // Create the model with 2 add nodes, the graph has 2 inputs with no shape + onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + std::vector inputs; + std::vector outputs; + + // FLOAT tensor without shape + ONNX_NAMESPACE::TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); + inputs.push_back(&input_arg_1); + inputs.push_back(&input_arg_2); + auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); + outputs.push_back(&output_arg); + graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); + + auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); + inputs.clear(); + inputs.push_back(&output_arg); + inputs.push_back(&input_arg_3); + auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_2); + graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); + + ASSERT_STATUS_OK(graph.Resolve()); + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name)); + } + + // test load only + // since we know NNAPI supports Add op, but both Add ops in the graph has no input shape + // verify the entire graph will not be assigned to NNAPI EP + SessionOptions so; + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique(0))); + ASSERT_STATUS_OK(session_object.Load(model_file_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_EQ(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0) + << "No node should be taken by the NNAPI EP"; +} + #endif // !(ORT_MINIMAL_BUILD TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {