diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 9b144b89fe..89e59c0850 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -1124,6 +1124,8 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node
   const std::unordered_set& outer_scope_node_args = resolve_context_.outer_scope_node_args;
   std::unordered_set inner_nodes;
 
+  std::unordered_set node_args_consumed_by_subgraphs;
+
   // recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs
   if (!resolve_context_.nodes_with_subgraphs.empty()) {
     for (auto* node : resolve_context_.nodes_with_subgraphs) {
@@ -1157,6 +1159,12 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node
                   "This is an invalid model. Failed to find NodeArg in all parent graphs. Name=", node_arg_name,
                   " Graph may not conform to the ONNX spec and contain initializers that are not graph inputs.");
             }
+          } else {
+            // This value may be produced by this graph, or it may still come from a parent graph if it is also
+            // consumed directly at this level, because we create a NodeArg for every Node input in this graph.
+            // To tell which, we need the outputs produced at this level, which are not available yet, so store
+            // the name here and check it before returning from BuildConnections.
+            ORT_IGNORE_RETURN_VALUE(node_args_consumed_by_subgraphs.insert(node_arg_name));
           }
 
           // add it to the Node's list of implicit inputs
@@ -1178,8 +1186,9 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node
 
             inner_nodes.insert(&output_node);
 
-            // If this Graph was built manually, remove the implicit input from the graph outputs if it is present there
-            // and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
+            // If this Graph was built manually, remove the implicit input from the graph outputs
+            // if it is present there and not explicitly listed in the ordered graph outputs
+            // (as that implies we should leave it as an output).
             // If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
             if (!is_loaded_from_model_file_) {
               graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
@@ -1252,8 +1261,17 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node
     }
   }
 
+  // Finally, check whether each NodeArg consumed by a subgraph is available locally, i.e. produced by a node
+  // in this graph or present in inputs_and_initializers. If not, add it to outer_scope_node_args_consumed.
+  for (const auto& name : node_args_consumed_by_subgraphs) {
+    if (node_arg_to_producer_node_.count(name) == 0 &&
+        resolve_context_.inputs_and_initializers.find(name) == resolve_context_.inputs_and_initializers.cend()) {
+      ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(name));
+    }
+  }
+
   return Status::OK();
-} // namespace onnxruntime
+}
 
 void Graph::ReverseDFSFrom(const std::vector& from,
                            const std::function& enter,
diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
index 41d29d97aa..868287474f 100644
--- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
+++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
@@ -929,6 +929,37 @@ TEST(Loop, PassThroughSubgraphInputNoTypeOrShape) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
+TEST(Loop, BugFixIssue4031_implicit_input_handling) {
+  SessionOptions so;
+  so.graph_optimization_level = TransformerLevel::Level2; // we need constant folding to run
+  InferenceSession session_object{so, GetEnvironment()};
+  static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/ort_github_issue_4031.onnx");
+
+  ASSERT_STATUS_OK(session_object.Load(MODEL_URI));
+  ASSERT_STATUS_OK(session_object.Initialize());
+
+  onnxruntime::RunOptions run_options;
+  run_options.run_tag = "BugFixIssue4031_implicit_input_handling";
+
+  // prepare inputs: a single float, 123, fed as the Loop's state variable
+  OrtValue ml_value;
+  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
+                &ml_value);
+  NameMLValMap feeds;
+  feeds.insert(std::make_pair("state_var_in", ml_value));
+
+  // prepare outputs
+  std::vector output_names{"state_var_out"};
+  std::vector fetches;
+
+  // Now run. One Loop iteration adds both initializers (1 + 1) to the state variable, so 123 becomes 125.
+  ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
+
+  const auto& output = fetches[0].Get();
+  ASSERT_TRUE(output.Shape().Size() == 1);
+  ASSERT_TRUE(output.Data()[0] == 125.f);
+}
+
 #ifdef USE_CUDA
 // test that when part of the subgraph run on CUDA it executes successfully
 TEST(Loop, MixedExecutionProviders) {
diff --git a/onnxruntime/test/testdata/ort_github_issue_4031.onnx b/onnxruntime/test/testdata/ort_github_issue_4031.onnx
new file mode 100644
index 0000000000..d75c755816
Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_4031.onnx differ
diff --git a/onnxruntime/test/testdata/ort_github_issue_4031.py b/onnxruntime/test/testdata/ort_github_issue_4031.py
new file mode 100644
index 0000000000..f991b41024
--- /dev/null
+++ b/onnxruntime/test/testdata/ort_github_issue_4031.py
@@ -0,0 +1,66 @@
+import onnx
+from onnx import helper
+from onnx import TensorProto
+
+if_body = helper.make_graph(
+    [
+        # need to use main_graph_initializer in a way that can't be constant folded
+        helper.make_node("Add", ["state_var_in", "main_graph_initializer"], ["add_out"], "If_add"),
+        helper.make_node("Cast", ["add_out"], ["output"], to=TensorProto.BOOL),
+    ],
+    "if_branch_body",
+    [
+        # no explicit inputs
+    ],
+    [
+        helper.make_tensor_value_info('output', TensorProto.BOOL, [1]),  # NOTE(review): open question from the author - how is this getting a type of float? TODO confirm
+    ])
+
+# Loop body graph with If node and usage of main_graph_initializer on this level
+body = helper.make_graph(
+    [
+        # Add node that can be constant folded. Creating it also creates a NodeArg for main_graph_initializer at
+        # this level, but that direct usage of the outer scope value goes away after constant folding
+        helper.make_node("Add", ["sub_graph_initializer", "main_graph_initializer"], ["initializer_sum"], "Add1"),
+        helper.make_node("Add", ["initializer_sum", "loop_state_in"], ["loop_state_out"], "Add2"),
+        # If node to create usage of main_graph_initializer another level down
+        helper.make_node("If", ["subgraph_keep_going_in"], ["subgraph_keep_going_out"], "If1",
+                         then_branch=if_body, else_branch=if_body),
+    ],
+    "Loop_body",
+    [
+        helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]),
+        helper.make_tensor_value_info('subgraph_keep_going_in', TensorProto.BOOL, [1]),
+        helper.make_tensor_value_info('loop_state_in', TensorProto.FLOAT, [1])
+    ],
+    [
+        helper.make_tensor_value_info('subgraph_keep_going_out', TensorProto.BOOL, [1]),
+        helper.make_tensor_value_info('loop_state_out', TensorProto.FLOAT, [1]),
+    ],
+    [
+        helper.make_tensor('sub_graph_initializer', TensorProto.FLOAT, [1], [1.])
+    ]
+)
+
+# Create the main graph
+graph_proto = helper.make_graph(
+    [
+        helper.make_node("Loop", ["max_trip_count", "keep_going", "state_var_in"],
+                         ["state_var_out"], "Loop1", body=body)
+    ],
+    "Main_graph",
+    [
+        helper.make_tensor_value_info('state_var_in', TensorProto.FLOAT, [1]),
+    ],
+    [
+        helper.make_tensor_value_info('state_var_out', TensorProto.FLOAT, [1]),
+    ],
+    [
+        helper.make_tensor('max_trip_count', TensorProto.INT64, [1], [1]),
+        helper.make_tensor('main_graph_initializer', TensorProto.FLOAT, [1], [1.]),
+        helper.make_tensor('keep_going', TensorProto.BOOL, [1], [True]),
+    ]
+)
+
+model = helper.make_model(graph_proto)
+onnx.save(model, 'ort_github_issue_4031.onnx')
\ No newline at end of file