From 5276bab268cffbc9fa93062a81b8a773ab56281f Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 11 May 2021 19:48:17 +1000 Subject: [PATCH] Fail model loading if node input is a missing value (#7459) * Add check for a node with an invalid input so we fail during model load. Without this we get a more cryptic failure in the allocation planner. * Add handling for manually created subgraph where outer scope node arg names are in a different place. * Update onnxruntime/core/graph/graph.cc Co-authored-by: Pranav Sharma * Ignore Fused nodes when checking inputs are valid. DML EP will remove initializers it has moved across, so they're available when the node runs but are no longer part of the ORT Graph instance. We have to fully resolve the graph before any node fusion happens, so the model was valid in the beginning (which is the main thing we are trying to validate). * Skip check in training build. Rules for allowing an 'invalid' input are unknown for those scenarios. * Only check the initial load for a training build. Co-authored-by: Pranav Sharma --- onnxruntime/core/graph/graph.cc | 21 +++++++++++++++++++- onnxruntime/test/ir/graph_test.cc | 32 +++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4b355501a4..321c48025b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1494,6 +1494,25 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(input_arg_name)); } } + } else { + // Check all the inputs are found. + // + // Ignore a Fused node as it could have moved things like initializers to a different device + // (they're internally available to the fused node but removed from the Graph instance). + // Fusion happens after the model was loaded in full so we know the inputs were valid originally. + bool check = node.NodeType() != Node::Type::Fused; +#if defined(ENABLE_TRAINING) + // Only check initial model load for training as graph modifications there also render inputs 'invalid'. + check = check && num_resolves_ == 0; +#endif + if (check && + resolve_context_.inputs_and_initializers.find(input_arg_name) == + resolve_context_.inputs_and_initializers.cend() && + // if we're manually creating a Graph for use as a subgraph the outer scope names are manually set + outer_scope_node_arg_names_.find(input_arg_name) == outer_scope_node_arg_names_.cend()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid model. Node input '", input_arg_name, + "' is not a graph input, initializer, or output of a previous node."); + } } } } @@ -2396,7 +2415,7 @@ void Graph::InitFunctionBodyForNode(Node& node) { } auto func_ptr = std::make_unique(*this, node.Index(), onnx_function_proto, - logger_); + logger_); function_container_.emplace_back(std::move(func_ptr)); node.SetFunctionBody(*function_container_.back()); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 6ea5688c9d..a29f4adb95 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -446,8 +446,8 @@ TEST_F(GraphTest, LocalCustomRegistry) { } // Tests the case where function op and function body ops belong to different domains. -// Tests that such a model can be loaded successfully, function body initialization is -// successful and domain and verison mapping for each node is successful (by verifying +// Tests that such a model can be loaded successfully, function body initialization is +// successful and domain and verison mapping for each node is successful (by verifying // op schema for each of the function body nodes can be found). TEST_F(GraphTest, FunctionOpsetImportTest) { std::shared_ptr model; @@ -1880,5 +1880,33 @@ TEST_F(GraphTest, SetInputsAndSetOutputs_NewInputAndOutput) { ASSERT_TRUE(std::find(outputs.begin(), outputs.end(), sum_with_z) != outputs.end()) << "expected new output sum_with_z"; } + +TEST_F(GraphTest, LoadModelMissingInput) { + ModelProto m; + m.set_ir_version(ONNX_NAMESPACE::IR_VERSION); + ImportOpset(m, "", 13); + GraphProto& g = *m.mutable_graph(); + NodeProto* node = g.add_node(); + *node->add_input() = "x"; + *node->add_input() = "y"; + *node->add_output() = "z"; + node->set_op_type("Reshape"); + node->set_domain(""); + + // add 'x' as a graph input but not 'y' + ValueInfoProto* input1 = g.add_input(); + input1->set_name("x"); + SetTypeAndShape(input1->mutable_type()->mutable_tensor_type(), 1, {4}); + ValueInfoProto* output = g.add_output(); + output->set_name("z"); + SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {2, 2}); + + std::shared_ptr model; + Status st = Model::Load(std::move(m), model, nullptr, *logger_); + ASSERT_FALSE(st.IsOK()); + ASSERT_THAT(st.ErrorMessage(), testing::HasSubstr("Invalid model. Node input 'y' is not a graph input, " + "initializer, or output of a previous node.")); +} + } // namespace test } // namespace onnxruntime