mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Handle edge case with implicit input and multiple levels of subgraphs (#4031)
* Handle edge case where an implicit input for a subgraph may not get wired in correctly.
Conditions required:
- two or more levels of nested subgraph
- an implicit input from above the bottom two levels is used in both levels of subgraph
- this creates a NodeArg for the implicit input at both levels
- something changes to the first level subgraph to no longer use the implicit input
- could be constant folding, could be partitioning of nodes results in a copy of the implicit input being made to a different device
When that occurs we lose the wiring through to the second level of nested subgraph as there's a NodeArg in the first level but the implicit input is no longer used there. Fix that by doing a final check for outer scope values once we know all the outputs produced by the current graph.
Found by commenting out the CUDA implementations of the control flow nodes and running ssd_mobilenet_300 from the mlperf models.
* Add test case.
This commit is contained in:
parent
c331d8cffc
commit
b85805ed01
4 changed files with 118 additions and 3 deletions
|
|
@ -1124,6 +1124,8 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
|
|||
const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
|
||||
std::unordered_set<Node*> inner_nodes;
|
||||
|
||||
std::unordered_set<std::string> node_args_consumed_by_subgraphs;
|
||||
|
||||
// recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs
|
||||
if (!resolve_context_.nodes_with_subgraphs.empty()) {
|
||||
for (auto* node : resolve_context_.nodes_with_subgraphs) {
|
||||
|
|
@ -1157,6 +1159,12 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
|
|||
"This is an invalid model. Failed to find NodeArg in all parent graphs. Name=", node_arg_name,
|
||||
" Graph may not conform to the ONNX spec and contain initializers that are not graph inputs.");
|
||||
}
|
||||
} else {
|
||||
// this value may be produced by this graph, or it could still be coming from a parent graph if it
|
||||
// is also directly consumed at this level as we create a NodeArg for all Node inputs in this graph.
|
||||
// due to that we need to check the outputs from this level to determine if it is an outer scope value.
|
||||
// we don't have that info yet so store and check before returning from BuildConnections
|
||||
ORT_IGNORE_RETURN_VALUE(node_args_consumed_by_subgraphs.insert(node_arg_name));
|
||||
}
|
||||
|
||||
// add it to the Node's list of implicit inputs
|
||||
|
|
@ -1178,8 +1186,9 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
|
|||
|
||||
inner_nodes.insert(&output_node);
|
||||
|
||||
// If this Graph was built manually, remove the implicit input from the graph outputs if it is present there
|
||||
// and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
|
||||
// If this Graph was built manually, remove the implicit input from the graph outputs
|
||||
// if it is present there and not explicitly listed in the ordered graph outputs
|
||||
// (as that implies we should leave it as an output).
|
||||
// If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
|
||||
if (!is_loaded_from_model_file_) {
|
||||
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
|
||||
|
|
@ -1252,8 +1261,17 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
|
|||
}
|
||||
}
|
||||
|
||||
// finally check any node args consumed by subgraphs to see if they're available locally.
|
||||
// if not we add them to the list of outer scope values consumed.
|
||||
for (const auto& name : node_args_consumed_by_subgraphs) {
|
||||
if (node_arg_to_producer_node_.count(name) == 0 &&
|
||||
resolve_context_.inputs_and_initializers.find(name) == resolve_context_.inputs_and_initializers.cend()) {
|
||||
ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(name));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
} // namespace onnxruntime
|
||||
}
|
||||
|
||||
void Graph::ReverseDFSFrom(const std::vector<NodeIndex>& from,
|
||||
const std::function<void(const Node*)>& enter,
|
||||
|
|
|
|||
|
|
@ -929,6 +929,37 @@ TEST(Loop, PassThroughSubgraphInputNoTypeOrShape) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(Loop, BugFixIssue4031_implicit_input_handling) {
|
||||
SessionOptions so;
|
||||
so.graph_optimization_level = TransformerLevel::Level2; // we need constant folding to run
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/ort_github_issue_4031.onnx");
|
||||
|
||||
ASSERT_STATUS_OK(session_object.Load(MODEL_URI));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
onnxruntime::RunOptions run_options;
|
||||
run_options.run_tag = "BugFixIssue4031_implicit_input_handling";
|
||||
|
||||
// prepare inputs
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
|
||||
&ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("state_var_in", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names{"state_var_out"};
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// Now run
|
||||
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
|
||||
|
||||
const auto& output = fetches[0].Get<Tensor>();
|
||||
ASSERT_TRUE(output.Shape().Size() == 1);
|
||||
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// test that when part of the subgraph run on CUDA it executes successfully
|
||||
TEST(Loop, MixedExecutionProviders) {
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/ort_github_issue_4031.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/ort_github_issue_4031.onnx
vendored
Normal file
Binary file not shown.
66
onnxruntime/test/testdata/ort_github_issue_4031.py
vendored
Normal file
66
onnxruntime/test/testdata/ort_github_issue_4031.py
vendored
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
if_body = helper.make_graph(
|
||||
[
|
||||
# need to use main_graph_initializer in a way that can't be constant folded
|
||||
helper.make_node("Add", ["state_var_in", "main_graph_initializer"], ["add_out"], "If_add"),
|
||||
helper.make_node("Cast", ["add_out"], ["output"], to=TensorProto.BOOL),
|
||||
],
|
||||
"if_branch_body",
|
||||
[
|
||||
# no explicit inputs
|
||||
],
|
||||
[
|
||||
helper.make_tensor_value_info('output', TensorProto.BOOL, [1]), # how is this getting a type of float?
|
||||
])
|
||||
|
||||
# Loop body graph with If node and usage of main_graph_initializer on this level
|
||||
body = helper.make_graph(
|
||||
[
|
||||
# Add node that can be constant folded. Creates NodeArg when created but that implicit usage of an outer scope
|
||||
# value main_graph_initializer goes away after constant folding
|
||||
helper.make_node("Add", ["sub_graph_initializer", "main_graph_initializer"], ["initializer_sum"], "Add1"),
|
||||
helper.make_node("Add", ["initializer_sum", "loop_state_in"], ["loop_state_out"], "Add2"),
|
||||
# If node to create usage of main_graph_initializer another level down
|
||||
helper.make_node("If", ["subgraph_keep_going_in"], ["subgraph_keep_going_out"], "If1",
|
||||
then_branch=if_body, else_branch=if_body),
|
||||
],
|
||||
"Loop_body",
|
||||
[
|
||||
helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]),
|
||||
helper.make_tensor_value_info('subgraph_keep_going_in', TensorProto.BOOL, [1]),
|
||||
helper.make_tensor_value_info('loop_state_in', TensorProto.FLOAT, [1])
|
||||
],
|
||||
[
|
||||
helper.make_tensor_value_info('subgraph_keep_going_out', TensorProto.BOOL, [1]),
|
||||
helper.make_tensor_value_info('loop_state_out', TensorProto.FLOAT, [1]),
|
||||
],
|
||||
[
|
||||
helper.make_tensor('sub_graph_initializer', TensorProto.FLOAT, [1], [1.])
|
||||
]
|
||||
)
|
||||
|
||||
# Create the main graph
|
||||
graph_proto = helper.make_graph(
|
||||
[
|
||||
helper.make_node("Loop", ["max_trip_count", "keep_going", "state_var_in"],
|
||||
["state_var_out"], "Loop1", body=body)
|
||||
],
|
||||
"Main_graph",
|
||||
[
|
||||
helper.make_tensor_value_info('state_var_in', TensorProto.FLOAT, [1]),
|
||||
],
|
||||
[
|
||||
helper.make_tensor_value_info('state_var_out', TensorProto.FLOAT, [1]),
|
||||
],
|
||||
[
|
||||
helper.make_tensor('max_trip_count', TensorProto.INT64, [1], [1]),
|
||||
helper.make_tensor('main_graph_initializer', TensorProto.FLOAT, [1], [1.]),
|
||||
helper.make_tensor('keep_going', TensorProto.BOOL, [1], [True]),
|
||||
]
|
||||
)
|
||||
|
||||
model = helper.make_model(graph_proto)
|
||||
onnx.save(model, 'ort_github_issue_4031.onnx')
|
||||
Loading…
Reference in a new issue