Fail model loading if node input is a missing value (#7459)

* Add check for a node with an invalid input so we fail during model load. Without this we get a more cryptic failure in the allocation planner.

* Add handling for manually created subgraph where outer scope node arg names are in a different place.

* Update onnxruntime/core/graph/graph.cc

Co-authored-by: Pranav Sharma <prs@microsoft.com>

* Ignore Fused nodes when checking inputs are valid. DML EP will remove initializers it has moved across, so they're available when the node runs but are no longer part of the ORT Graph instance.

We have to fully resolve the graph before any node fusion happens, so the model was valid in the beginning (which is the main thing we are trying to validate).

* Skip check in training build. Rules for allowing an 'invalid' input are unknown for those scenarios.

* Only check the initial load for a training build.

Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
Scott McKay 2021-05-11 19:48:17 +10:00 committed by GitHub
parent 90c65ac171
commit 5276bab268
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 3 deletions

View file

@ -1494,6 +1494,25 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(input_arg_name));
}
}
} else {
// Check all the inputs are found.
//
// Ignore a Fused node as it could have moved things like initializers to a different device
// (they're internally available to the fused node but removed from the Graph instance).
// Fusion happens after the model was loaded in full so we know the inputs were valid originally.
bool check = node.NodeType() != Node::Type::Fused;
#if defined(ENABLE_TRAINING)
// Only check initial model load for training as graph modifications there also render inputs 'invalid'.
check = check && num_resolves_ == 0;
#endif
if (check &&
resolve_context_.inputs_and_initializers.find(input_arg_name) ==
resolve_context_.inputs_and_initializers.cend() &&
// if we're manually creating a Graph for use as a subgraph the outer scope names are manually set
outer_scope_node_arg_names_.find(input_arg_name) == outer_scope_node_arg_names_.cend()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid model. Node input '", input_arg_name,
"' is not a graph input, initializer, or output of a previous node.");
}
}
}
}
@ -2396,7 +2415,7 @@ void Graph::InitFunctionBodyForNode(Node& node) {
}
auto func_ptr = std::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), onnx_function_proto,
logger_);
logger_);
function_container_.emplace_back(std::move(func_ptr));
node.SetFunctionBody(*function_container_.back());

View file

@ -446,8 +446,8 @@ TEST_F(GraphTest, LocalCustomRegistry) {
}
// Tests the case where function op and function body ops belong to different domains.
// Tests that such a model can be loaded successfully, function body initialization is
// successful and domain and verison mapping for each node is successful (by verifying
// Tests that such a model can be loaded successfully, function body initialization is
// successful and domain and verison mapping for each node is successful (by verifying
// op schema for each of the function body nodes can be found).
TEST_F(GraphTest, FunctionOpsetImportTest) {
std::shared_ptr<Model> model;
@ -1880,5 +1880,33 @@ TEST_F(GraphTest, SetInputsAndSetOutputs_NewInputAndOutput) {
ASSERT_TRUE(std::find(outputs.begin(), outputs.end(), sum_with_z) != outputs.end())
<< "expected new output sum_with_z";
}
TEST_F(GraphTest, LoadModelMissingInput) {
ModelProto m;
m.set_ir_version(ONNX_NAMESPACE::IR_VERSION);
ImportOpset(m, "", 13);
GraphProto& g = *m.mutable_graph();
NodeProto* node = g.add_node();
*node->add_input() = "x";
*node->add_input() = "y";
*node->add_output() = "z";
node->set_op_type("Reshape");
node->set_domain("");
// add 'x' as a graph input but not 'y'
ValueInfoProto* input1 = g.add_input();
input1->set_name("x");
SetTypeAndShape(input1->mutable_type()->mutable_tensor_type(), 1, {4});
ValueInfoProto* output = g.add_output();
output->set_name("z");
SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {2, 2});
std::shared_ptr<Model> model;
Status st = Model::Load(std::move(m), model, nullptr, *logger_);
ASSERT_FALSE(st.IsOK());
ASSERT_THAT(st.ErrorMessage(), testing::HasSubstr("Invalid model. Node input 'y' is not a graph input, "
"initializer, or output of a previous node."));
}
} // namespace test
} // namespace onnxruntime