diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4b355501a4..321c48025b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1494,6 +1494,25 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(input_arg_name)); } } + } else { + // Check all the inputs are found. + // + // Ignore a Fused node as it could have moved things like initializers to a different device + // (they're internally available to the fused node but removed from the Graph instance). + // Fusion happens after the model was loaded in full so we know the inputs were valid originally. + bool check = node.NodeType() != Node::Type::Fused; +#if defined(ENABLE_TRAINING) + // Only check initial model load for training as graph modifications there also render inputs 'invalid'. + check = check && num_resolves_ == 0; +#endif + if (check && + resolve_context_.inputs_and_initializers.find(input_arg_name) == + resolve_context_.inputs_and_initializers.cend() && + // if we're manually creating a Graph for use as a subgraph the outer scope names are manually set + outer_scope_node_arg_names_.find(input_arg_name) == outer_scope_node_arg_names_.cend()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid model. Node input '", input_arg_name, + "' is not a graph input, initializer, or output of a previous node."); + } } } } @@ -2396,7 +2415,7 @@ void Graph::InitFunctionBodyForNode(Node& node) { } auto func_ptr = std::make_unique(*this, node.Index(), onnx_function_proto, - logger_); + logger_); function_container_.emplace_back(std::move(func_ptr)); node.SetFunctionBody(*function_container_.back()); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 6ea5688c9d..a29f4adb95 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -446,8 +446,8 @@ TEST_F(GraphTest, LocalCustomRegistry) { } // Tests the case where function op and function body ops belong to different domains. -// Tests that such a model can be loaded successfully, function body initialization is -// successful and domain and verison mapping for each node is successful (by verifying +// Tests that such a model can be loaded successfully, function body initialization is +// successful and domain and verison mapping for each node is successful (by verifying // op schema for each of the function body nodes can be found). TEST_F(GraphTest, FunctionOpsetImportTest) { std::shared_ptr model; @@ -1880,5 +1880,33 @@ TEST_F(GraphTest, SetInputsAndSetOutputs_NewInputAndOutput) { ASSERT_TRUE(std::find(outputs.begin(), outputs.end(), sum_with_z) != outputs.end()) << "expected new output sum_with_z"; } + +TEST_F(GraphTest, LoadModelMissingInput) { + ModelProto m; + m.set_ir_version(ONNX_NAMESPACE::IR_VERSION); + ImportOpset(m, "", 13); + GraphProto& g = *m.mutable_graph(); + NodeProto* node = g.add_node(); + *node->add_input() = "x"; + *node->add_input() = "y"; + *node->add_output() = "z"; + node->set_op_type("Reshape"); + node->set_domain(""); + + // add 'x' as a graph input but not 'y' + ValueInfoProto* input1 = g.add_input(); + input1->set_name("x"); + SetTypeAndShape(input1->mutable_type()->mutable_tensor_type(), 1, {4}); + ValueInfoProto* output = g.add_output(); + output->set_name("z"); + SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {2, 2}); + + std::shared_ptr model; + Status st = Model::Load(std::move(m), model, nullptr, *logger_); + ASSERT_FALSE(st.IsOK()); + ASSERT_THAT(st.ErrorMessage(), testing::HasSubstr("Invalid model. Node input 'y' is not a graph input, " + "initializer, or output of a previous node.")); +} + } // namespace test } // namespace onnxruntime