diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 1a67e50e0f..4a0153b86d 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -727,27 +727,15 @@ class Graph { ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name)); } - /** When programmatically constructing a Graph, explicitly set the order to use for graph inputs when the graph is - resolved. - This will determine the graph input order when the Graph is converted to a GraphProto by Graph::ToGraphProto. - @param inputs NodeArgs that represent graph inputs which need to be explicitly ordered. - Any graph inputs not in this list will be appended to the ordered graph input list, in the order that they were first - used by Nodes (i.e. the order of Node creation implicitly determines the ordering). + /** When programmatically constructing a Graph, explicitly set graph inputs. + @param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered. @remarks If the Graph was loaded from a GraphProto this has no effect.*/ - void SetInputOrder(const std::vector inputs) { - graph_input_order_ = inputs; - } + void SetInputs(const std::vector inputs); - /** When programmatically constructing a Graph, explicitly set the order to use for graph outputs when the graph is - resolved. - This will determine the graph output order when the Graph is converted to a GraphProto by Graph::ToGraphProto. - @param outputs NodeArgs that represent graph outputs which need to be explicitly ordered. - Any graph outputs not in this list will be appended to the ordered graph output list, in the order that they were first - produced by Nodes (i.e. the order of Node creation implicitly determines the ordering). + /** When programmatically constructing a Graph, explicitly set graph outputs. + @param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered. @remarks If the Graph was loaded from a GraphProto this has no effect.*/ - void SetOutputOrder(const std::vector outputs) { - graph_output_order_ = outputs; - } + void SetOutputs(const std::vector outputs); /** Returns true if this is a subgraph or fase if it is a high-level graph. */ bool IsSubgraph() const { return parent_graph_ != nullptr; } @@ -945,12 +933,14 @@ class Graph { // Full list of graph inputs. Matches number and order of inputs in the GraphProto. std::vector graph_inputs_including_initializers_; + bool graph_inputs_manually_set_ = false; // Graph inputs excluding initializers. std::vector graph_inputs_excluding_initializers_; // Graph outputs. std::vector graph_outputs_; + bool graph_outputs_manually_set_ = false; // Graph value_info. std::vector value_info_; @@ -975,12 +965,6 @@ class Graph { // NodeArgs that come from outer scope. Used when building a graph so that // these don't get recorded as graph inputs in the GraphProto. std::unordered_set outer_scope_node_arg_names_; - - // Explicit graph input order to be used when constructing a Graph manually. - std::vector graph_input_order_; - - // Explicit graph output order to be used when constructing a Graph manually. - std::vector graph_output_order_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e8db4acbda..7f5437e1fb 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -940,11 +940,8 @@ Status Graph::BuildConnections(std::vector& outer_scope_node_args_c // and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output). // If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is. if (!loaded_from_model_file) { - auto in_ordered_graph_outputs = find(graph_output_order_.cbegin(), graph_output_order_.cend(), node_arg); - if (in_ordered_graph_outputs == graph_output_order_.cend()) { - graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg), - graph_outputs_.end()); - } + graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg), + graph_outputs_.end()); } } } @@ -2219,10 +2216,8 @@ void Graph::CleanUnusedInitializers() { GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...) Status Graph::SetGraphInputsOutputs() { - // Reset graph inputs/outputs/value info state. + // Reset graph inputs excluding initializers/value_info. graph_inputs_excluding_initializers_.clear(); - graph_inputs_including_initializers_.clear(); - graph_outputs_.clear(); value_info_.clear(); // Flag indicates that this graph is loaded from model file. @@ -2231,10 +2226,11 @@ Status Graph::SetGraphInputsOutputs() { // and outputs will be inferred. const bool loaded_from_model_file = GraphLoadedFromModelFile(graph_proto_); - // if something is coming from outer scope, consider it already added - std::unordered_set added_input_names{outer_scope_node_arg_names_}; - if (loaded_from_model_file) { + // Reset graph inputs/outputs. + graph_inputs_including_initializers_.clear(); + graph_outputs_.clear(); + // Name to NodeArg mapping of all graph initializers. std::unordered_map graph_initializers; @@ -2302,49 +2298,31 @@ Status Graph::SetGraphInputsOutputs() { } } else { - std::unordered_map output_name_to_node_arg; - std::vector ordered_output_names; + std::unordered_map output_name_to_node_arg_index; + std::vector output_node_args_in_order; - // add any explicitly ordered inputs - for (auto* node_arg : graph_input_order_) { - if (!node_arg || !node_arg->Exists()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered inputs"); - } - - added_input_names.insert(node_arg->Name()); - graph_inputs_including_initializers_.push_back(node_arg); - if (name_to_initial_tensor_.find(node_arg->Name()) == name_to_initial_tensor_.end()) { - graph_inputs_excluding_initializers_.push_back(node_arg); - } + // if something is coming from outer scope, consider it already added + std::unordered_set added_input_names{outer_scope_node_arg_names_}; + if (!graph_inputs_manually_set_) { + graph_inputs_including_initializers_.clear(); } - // add any explicitly ordered outputs - for (auto* node_arg : graph_output_order_) { - if (!node_arg || !node_arg->Exists()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered outputs"); - } - output_name_to_node_arg.insert({node_arg->Name(), node_arg}); - ordered_output_names.push_back(node_arg->Name()); + if (!graph_outputs_manually_set_) { + graph_outputs_.clear(); } - // add all other outputs + // Collect all nodes' outputs for (const auto& node : Nodes()) { for (const auto* output_def : node.OutputDefs()) { if (output_def->Exists()) { - auto& name = output_def->Name(); - // check it wasn't in the explicitly ordered outputs - if (output_name_to_node_arg.find(name) == output_name_to_node_arg.cend()) { - output_name_to_node_arg.insert({name, output_def}); - ordered_output_names.push_back(name); - } + output_node_args_in_order.push_back(output_def); + output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1}); } } } // Init graph output args with copy of all node output args. - auto graph_output_args = output_name_to_node_arg; - std::unordered_set inner_nodes; - + auto graph_output_args = output_name_to_node_arg_index; for (const auto& node : Nodes()) { // Go thru all node's inputs. for (const auto* input_arg : node.InputDefs()) { @@ -2353,15 +2331,28 @@ Status Graph::SetGraphInputsOutputs() { continue; } - auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name()); - if (output_name_to_node_arg.end() == output_arg_iter) { + auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name()); + if (output_name_to_node_arg_index.end() == output_arg_iter) { // This input arg should be fed when running evaluation. // it should be a graph input. const std::string& name = input_arg->Name(); if (added_input_names.end() == added_input_names.find(name)) { // This graph input has not been added into . - graph_inputs_including_initializers_.push_back(input_arg); - + if (!graph_inputs_manually_set_) { + graph_inputs_including_initializers_.push_back(input_arg); + } else { + // Validation: the must be in graph inputs or initializers when it's manually set. + auto& inputs = GetInputsIncludingInitializers(); + auto iter = std::find(inputs.begin(), inputs.end(), input_arg); + if (inputs.end() == iter) { + // it's not in graph inputs. + auto initializers = GetAllInitializedTensors(); + if (initializers.end() == initializers.find(input_arg->Name())) { + // It's not in graph initializers. + return Status(ONNXRUNTIME, FAIL, input_arg->Name() + " must be either specified in graph inputs or graph initailizers."); + } + } + } if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) { graph_inputs_excluding_initializers_.push_back(input_arg); } @@ -2377,12 +2368,15 @@ Status Graph::SetGraphInputsOutputs() { } } - // Set graph outputs - auto end = graph_output_args.end(); - for (auto& name : ordered_output_names) { - auto graph_output = graph_output_args.find(name); - if (graph_output != end) { - graph_outputs_.push_back(graph_output->second); + if (!graph_outputs_manually_set_) { + // Set graph outputs in order. + std::vector graph_output_args_index; + for (auto output_arg : graph_output_args) { + graph_output_args_index.push_back(output_arg.second); + } + std::sort(graph_output_args_index.begin(), graph_output_args_index.end()); + for (auto& output_arg_index : graph_output_args_index) { + graph_outputs_.push_back(output_node_args_in_order[output_arg_index]); } } } @@ -2483,6 +2477,25 @@ Status Graph::InlineFunction(Node& node) { return Status::OK(); } +void Graph::SetInputs(const std::vector inputs) { + if (GraphLoadedFromModelFile(graph_proto_)) { + // TODO: add this support. + ORT_THROW("This API is not supported when model is loaded from proto file right now."); + } + + graph_inputs_including_initializers_ = inputs; + graph_inputs_manually_set_ = true; +} + +void Graph::SetOutputs(const std::vector outputs) { + if (GraphLoadedFromModelFile(graph_proto_)) { + // TODO: add this support. + ORT_THROW("This API is not supported when model is loaded from proto file right now."); + } + graph_outputs_ = outputs; + graph_outputs_manually_set_ = true; +} + void Graph::AddFunction(const ONNX_NAMESPACE::FunctionProto* func_proto) { this->model_functions_[func_proto->name()] = func_proto; } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 929517e63d..74ba5b025e 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -443,12 +443,6 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) for (auto i = 0; i < 20; ++i) { map.insert({std::to_string(i), i}); - - std::cout << "Insert " << i << "\n"; - for (auto pair : map) { - std::cout << pair.first << ":" << pair.second << " "; - } - std::cout << "\n"; } // | | @@ -458,10 +452,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) // | // d (Split) // / \ - // 1 .. 10 - std::unordered_map, std::vector>> - expected_node_name_to_input_output_args; - + // 1 .. 10 TypeProto tensor_int32; tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); @@ -475,37 +466,36 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) auto& output_arg_c = graph.GetOrCreateNodeArg("node_c_out_1", &tensor_int32); std::vector split_outputs; + std::vector graph_outputs; for (int i = 0; i < 10; ++i) { - split_outputs.push_back(&graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32)); + auto arg = &graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32); + split_outputs.push_back(arg); + graph_outputs.push_back(arg); } - + std::reverse(graph_outputs.begin(), graph_outputs.end()); std::vector inputs; std::vector outputs; inputs.push_back(&input_arg_a); outputs.push_back(&output_arg_a); - expected_node_name_to_input_output_args["a"] = {inputs, outputs}; graph.AddNode("a", "Identity_Fake", "a", inputs, outputs); inputs.resize(2); inputs[0] = &output_arg_b; inputs[1] = &output_arg_a; outputs[0] = &output_arg_c; - expected_node_name_to_input_output_args["c"] = {inputs, outputs}; graph.AddNode("c", "Merge_Fake", "c", inputs, outputs); // deliberately add 'b' after 'c' to mix up the inputs as well inputs.resize(1); inputs[0] = &input_arg_b; outputs[0] = &output_arg_b; - expected_node_name_to_input_output_args["b"] = {inputs, outputs}; graph.AddNode("b", "Identity_Fake", "b", inputs, outputs); inputs[0] = &output_arg_c; - expected_node_name_to_input_output_args["d"] = {inputs, split_outputs}; graph.AddNode("d", "Split_Fake", "d", inputs, split_outputs); - auto validate_inputs_outputs = [&split_outputs](const Graph& graph) { + auto validate_inputs_outputs = [&graph_outputs](const Graph& graph) { auto inputs = graph.GetInputs(); auto outputs = graph.GetOutputs(); @@ -516,10 +506,11 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) ASSERT_TRUE(outputs.size() == 10); for (int i = 0; i < 10; ++i) { - EXPECT_TRUE(split_outputs[i]->Name() == outputs[i]->Name()); + EXPECT_TRUE(graph_outputs[i]->Name() == outputs[i]->Name()); } }; - + graph.SetInputs({&input_arg_a, &input_arg_b}); + graph.SetOutputs(graph_outputs); auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index bdbe3515b9..8ecbfb9656 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -268,8 +268,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options } } - graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in}); - graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0}); + graph.SetInputs({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in}); + graph.SetOutputs({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0}); // optional input backed by an initializer to make sure that's handled too. // we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided @@ -447,8 +447,8 @@ TEST(Loop, InfiniteLoopTermination) { graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs); } - graph.SetInputOrder({&iter_num_in, &cond_in, &outer_scope_0}); - graph.SetOutputOrder({&cond_out, &loop_var_0_out}); + graph.SetInputs({&iter_num_in, &cond_in, &outer_scope_0}); + graph.SetOutputs({&cond_out, &loop_var_0_out}); auto status = graph.Resolve(); EXPECT_EQ(status, Status::OK()); diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 8d4991ff90..1e9e348e4c 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -1046,8 +1046,8 @@ void MixedTypeInputs(bool is_v8) { graph.AddNode("node3", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1}); graph.AddNode("node4", "Identity", "Copy scan_in_2 to state_out_2", {&scan_in_2}, {&state_out_2}); - graph.SetInputOrder({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2}); - graph.SetOutputOrder({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2}); + graph.SetInputs({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2}); + graph.SetOutputs({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2}); auto status = graph.Resolve(); EXPECT_EQ(status, Status::OK()); @@ -1108,8 +1108,8 @@ void UnknownDimInSubgraphOutput(bool is_v8) { graph.AddNode("node1", "Identity", "Copy state_in_1 to scan_out_1", {&state_in_1}, {&scan_out_1}); graph.AddNode("node2", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1}); - graph.SetInputOrder({&state_in_1, &scan_in_1}); - graph.SetOutputOrder({&state_out_1, &scan_out_1}); + graph.SetInputs({&state_in_1, &scan_in_1}); + graph.SetOutputs({&state_out_1, &scan_out_1}); auto status = graph.Resolve(); EXPECT_EQ(status, Status::OK());