mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
allow users to set graph inputs and outputs fully. (#905)
* allow users to set graph inputs and outputs fully. * update * update the comments of the APIs * update * remove commented-out codes. * fix test failures. * fix comments. * adding more check to throw not support exception right now.
This commit is contained in:
parent
bb58806872
commit
f39a8d1f59
5 changed files with 90 additions and 102 deletions
|
|
@ -727,27 +727,15 @@ class Graph {
|
|||
ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name));
|
||||
}
|
||||
|
||||
/** When programmatically constructing a Graph, explicitly set the order to use for graph inputs when the graph is
|
||||
resolved.
|
||||
This will determine the graph input order when the Graph is converted to a GraphProto by Graph::ToGraphProto.
|
||||
@param inputs NodeArgs that represent graph inputs which need to be explicitly ordered.
|
||||
Any graph inputs not in this list will be appended to the ordered graph input list, in the order that they were first
|
||||
used by Nodes (i.e. the order of Node creation implicitly determines the ordering).
|
||||
/** When programmatically constructing a Graph, explicitly set graph inputs.
|
||||
@param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered.
|
||||
@remarks If the Graph was loaded from a GraphProto this has no effect.*/
|
||||
void SetInputOrder(const std::vector<const NodeArg*> inputs) {
|
||||
graph_input_order_ = inputs;
|
||||
}
|
||||
void SetInputs(const std::vector<const NodeArg*> inputs);
|
||||
|
||||
/** When programmatically constructing a Graph, explicitly set the order to use for graph outputs when the graph is
|
||||
resolved.
|
||||
This will determine the graph output order when the Graph is converted to a GraphProto by Graph::ToGraphProto.
|
||||
@param outputs NodeArgs that represent graph outputs which need to be explicitly ordered.
|
||||
Any graph outputs not in this list will be appended to the ordered graph output list, in the order that they were first
|
||||
produced by Nodes (i.e. the order of Node creation implicitly determines the ordering).
|
||||
/** When programmatically constructing a Graph, explicitly set graph outputs.
|
||||
@param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered.
|
||||
@remarks If the Graph was loaded from a GraphProto this has no effect.*/
|
||||
void SetOutputOrder(const std::vector<const NodeArg*> outputs) {
|
||||
graph_output_order_ = outputs;
|
||||
}
|
||||
void SetOutputs(const std::vector<const NodeArg*> outputs);
|
||||
|
||||
/** Returns true if this is a subgraph or fase if it is a high-level graph. */
|
||||
bool IsSubgraph() const { return parent_graph_ != nullptr; }
|
||||
|
|
@ -945,12 +933,14 @@ class Graph {
|
|||
|
||||
// Full list of graph inputs. Matches number and order of inputs in the GraphProto.
|
||||
std::vector<const NodeArg*> graph_inputs_including_initializers_;
|
||||
bool graph_inputs_manually_set_ = false;
|
||||
|
||||
// Graph inputs excluding initializers.
|
||||
std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
|
||||
|
||||
// Graph outputs.
|
||||
std::vector<const NodeArg*> graph_outputs_;
|
||||
bool graph_outputs_manually_set_ = false;
|
||||
|
||||
// Graph value_info.
|
||||
std::vector<const NodeArg*> value_info_;
|
||||
|
|
@ -975,12 +965,6 @@ class Graph {
|
|||
// NodeArgs that come from outer scope. Used when building a graph so that
|
||||
// these don't get recorded as graph inputs in the GraphProto.
|
||||
std::unordered_set<std::string> outer_scope_node_arg_names_;
|
||||
|
||||
// Explicit graph input order to be used when constructing a Graph manually.
|
||||
std::vector<const NodeArg*> graph_input_order_;
|
||||
|
||||
// Explicit graph output order to be used when constructing a Graph manually.
|
||||
std::vector<const NodeArg*> graph_output_order_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -940,11 +940,8 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_c
|
|||
// and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
|
||||
// If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
|
||||
if (!loaded_from_model_file) {
|
||||
auto in_ordered_graph_outputs = find(graph_output_order_.cbegin(), graph_output_order_.cend(), node_arg);
|
||||
if (in_ordered_graph_outputs == graph_output_order_.cend()) {
|
||||
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
|
||||
graph_outputs_.end());
|
||||
}
|
||||
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
|
||||
graph_outputs_.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2219,10 +2216,8 @@ void Graph::CleanUnusedInitializers() {
|
|||
|
||||
GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...)
|
||||
Status Graph::SetGraphInputsOutputs() {
|
||||
// Reset graph inputs/outputs/value info state.
|
||||
// Reset graph inputs excluding initializers/value_info.
|
||||
graph_inputs_excluding_initializers_.clear();
|
||||
graph_inputs_including_initializers_.clear();
|
||||
graph_outputs_.clear();
|
||||
value_info_.clear();
|
||||
|
||||
// Flag indicates that this graph is loaded from model file.
|
||||
|
|
@ -2231,10 +2226,11 @@ Status Graph::SetGraphInputsOutputs() {
|
|||
// and outputs will be inferred.
|
||||
const bool loaded_from_model_file = GraphLoadedFromModelFile(graph_proto_);
|
||||
|
||||
// if something is coming from outer scope, consider it already added
|
||||
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
|
||||
|
||||
if (loaded_from_model_file) {
|
||||
// Reset graph inputs/outputs.
|
||||
graph_inputs_including_initializers_.clear();
|
||||
graph_outputs_.clear();
|
||||
|
||||
// Name to NodeArg mapping of all graph initializers.
|
||||
std::unordered_map<std::string, const NodeArg*> graph_initializers;
|
||||
|
||||
|
|
@ -2302,49 +2298,31 @@ Status Graph::SetGraphInputsOutputs() {
|
|||
}
|
||||
|
||||
} else {
|
||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||
std::vector<std::string> ordered_output_names;
|
||||
std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
|
||||
std::vector<const NodeArg*> output_node_args_in_order;
|
||||
|
||||
// add any explicitly ordered inputs
|
||||
for (auto* node_arg : graph_input_order_) {
|
||||
if (!node_arg || !node_arg->Exists()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered inputs");
|
||||
}
|
||||
|
||||
added_input_names.insert(node_arg->Name());
|
||||
graph_inputs_including_initializers_.push_back(node_arg);
|
||||
if (name_to_initial_tensor_.find(node_arg->Name()) == name_to_initial_tensor_.end()) {
|
||||
graph_inputs_excluding_initializers_.push_back(node_arg);
|
||||
}
|
||||
// if something is coming from outer scope, consider it already added
|
||||
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
|
||||
if (!graph_inputs_manually_set_) {
|
||||
graph_inputs_including_initializers_.clear();
|
||||
}
|
||||
|
||||
// add any explicitly ordered outputs
|
||||
for (auto* node_arg : graph_output_order_) {
|
||||
if (!node_arg || !node_arg->Exists()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered outputs");
|
||||
}
|
||||
output_name_to_node_arg.insert({node_arg->Name(), node_arg});
|
||||
ordered_output_names.push_back(node_arg->Name());
|
||||
if (!graph_outputs_manually_set_) {
|
||||
graph_outputs_.clear();
|
||||
}
|
||||
|
||||
// add all other outputs
|
||||
// Collect all nodes' outputs
|
||||
for (const auto& node : Nodes()) {
|
||||
for (const auto* output_def : node.OutputDefs()) {
|
||||
if (output_def->Exists()) {
|
||||
auto& name = output_def->Name();
|
||||
// check it wasn't in the explicitly ordered outputs
|
||||
if (output_name_to_node_arg.find(name) == output_name_to_node_arg.cend()) {
|
||||
output_name_to_node_arg.insert({name, output_def});
|
||||
ordered_output_names.push_back(name);
|
||||
}
|
||||
output_node_args_in_order.push_back(output_def);
|
||||
output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Init graph output args with copy of all node output args.
|
||||
auto graph_output_args = output_name_to_node_arg;
|
||||
std::unordered_set<Node*> inner_nodes;
|
||||
|
||||
auto graph_output_args = output_name_to_node_arg_index;
|
||||
for (const auto& node : Nodes()) {
|
||||
// Go thru all node's inputs.
|
||||
for (const auto* input_arg : node.InputDefs()) {
|
||||
|
|
@ -2353,15 +2331,28 @@ Status Graph::SetGraphInputsOutputs() {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
|
||||
if (output_name_to_node_arg.end() == output_arg_iter) {
|
||||
auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());
|
||||
if (output_name_to_node_arg_index.end() == output_arg_iter) {
|
||||
// This input arg should be fed when running evaluation.
|
||||
// it should be a graph input.
|
||||
const std::string& name = input_arg->Name();
|
||||
if (added_input_names.end() == added_input_names.find(name)) {
|
||||
// This graph input has not been added into <graph_inputs_>.
|
||||
graph_inputs_including_initializers_.push_back(input_arg);
|
||||
|
||||
if (!graph_inputs_manually_set_) {
|
||||
graph_inputs_including_initializers_.push_back(input_arg);
|
||||
} else {
|
||||
// Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
|
||||
auto& inputs = GetInputsIncludingInitializers();
|
||||
auto iter = std::find(inputs.begin(), inputs.end(), input_arg);
|
||||
if (inputs.end() == iter) {
|
||||
// it's not in graph inputs.
|
||||
auto initializers = GetAllInitializedTensors();
|
||||
if (initializers.end() == initializers.find(input_arg->Name())) {
|
||||
// It's not in graph initializers.
|
||||
return Status(ONNXRUNTIME, FAIL, input_arg->Name() + " must be either specified in graph inputs or graph initailizers.");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) {
|
||||
graph_inputs_excluding_initializers_.push_back(input_arg);
|
||||
}
|
||||
|
|
@ -2377,12 +2368,15 @@ Status Graph::SetGraphInputsOutputs() {
|
|||
}
|
||||
}
|
||||
|
||||
// Set graph outputs
|
||||
auto end = graph_output_args.end();
|
||||
for (auto& name : ordered_output_names) {
|
||||
auto graph_output = graph_output_args.find(name);
|
||||
if (graph_output != end) {
|
||||
graph_outputs_.push_back(graph_output->second);
|
||||
if (!graph_outputs_manually_set_) {
|
||||
// Set graph outputs in order.
|
||||
std::vector<size_t> graph_output_args_index;
|
||||
for (auto output_arg : graph_output_args) {
|
||||
graph_output_args_index.push_back(output_arg.second);
|
||||
}
|
||||
std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
|
||||
for (auto& output_arg_index : graph_output_args_index) {
|
||||
graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2483,6 +2477,25 @@ Status Graph::InlineFunction(Node& node) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void Graph::SetInputs(const std::vector<const NodeArg*> inputs) {
|
||||
if (GraphLoadedFromModelFile(graph_proto_)) {
|
||||
// TODO: add this support.
|
||||
ORT_THROW("This API is not supported when model is loaded from proto file right now.");
|
||||
}
|
||||
|
||||
graph_inputs_including_initializers_ = inputs;
|
||||
graph_inputs_manually_set_ = true;
|
||||
}
|
||||
|
||||
void Graph::SetOutputs(const std::vector<const NodeArg*> outputs) {
|
||||
if (GraphLoadedFromModelFile(graph_proto_)) {
|
||||
// TODO: add this support.
|
||||
ORT_THROW("This API is not supported when model is loaded from proto file right now.");
|
||||
}
|
||||
graph_outputs_ = outputs;
|
||||
graph_outputs_manually_set_ = true;
|
||||
}
|
||||
|
||||
void Graph::AddFunction(const ONNX_NAMESPACE::FunctionProto* func_proto) {
|
||||
this->model_functions_[func_proto->name()] = func_proto;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -443,12 +443,6 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
|
||||
for (auto i = 0; i < 20; ++i) {
|
||||
map.insert({std::to_string(i), i});
|
||||
|
||||
std::cout << "Insert " << i << "\n";
|
||||
for (auto pair : map) {
|
||||
std::cout << pair.first << ":" << pair.second << " ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
// | |
|
||||
|
|
@ -458,10 +452,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
// |
|
||||
// d (Split)
|
||||
// / \
|
||||
// 1 .. 10
|
||||
std::unordered_map<std::string, std::pair<std::vector<NodeArg*>, std::vector<NodeArg*>>>
|
||||
expected_node_name_to_input_output_args;
|
||||
|
||||
// 1 .. 10
|
||||
TypeProto tensor_int32;
|
||||
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
|
||||
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
|
@ -475,37 +466,36 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
auto& output_arg_c = graph.GetOrCreateNodeArg("node_c_out_1", &tensor_int32);
|
||||
|
||||
std::vector<NodeArg*> split_outputs;
|
||||
std::vector<const NodeArg*> graph_outputs;
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
split_outputs.push_back(&graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32));
|
||||
auto arg = &graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32);
|
||||
split_outputs.push_back(arg);
|
||||
graph_outputs.push_back(arg);
|
||||
}
|
||||
|
||||
std::reverse(graph_outputs.begin(), graph_outputs.end());
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
|
||||
inputs.push_back(&input_arg_a);
|
||||
outputs.push_back(&output_arg_a);
|
||||
expected_node_name_to_input_output_args["a"] = {inputs, outputs};
|
||||
graph.AddNode("a", "Identity_Fake", "a", inputs, outputs);
|
||||
|
||||
inputs.resize(2);
|
||||
inputs[0] = &output_arg_b;
|
||||
inputs[1] = &output_arg_a;
|
||||
outputs[0] = &output_arg_c;
|
||||
expected_node_name_to_input_output_args["c"] = {inputs, outputs};
|
||||
graph.AddNode("c", "Merge_Fake", "c", inputs, outputs);
|
||||
|
||||
// deliberately add 'b' after 'c' to mix up the inputs as well
|
||||
inputs.resize(1);
|
||||
inputs[0] = &input_arg_b;
|
||||
outputs[0] = &output_arg_b;
|
||||
expected_node_name_to_input_output_args["b"] = {inputs, outputs};
|
||||
graph.AddNode("b", "Identity_Fake", "b", inputs, outputs);
|
||||
|
||||
inputs[0] = &output_arg_c;
|
||||
expected_node_name_to_input_output_args["d"] = {inputs, split_outputs};
|
||||
graph.AddNode("d", "Split_Fake", "d", inputs, split_outputs);
|
||||
|
||||
auto validate_inputs_outputs = [&split_outputs](const Graph& graph) {
|
||||
auto validate_inputs_outputs = [&graph_outputs](const Graph& graph) {
|
||||
auto inputs = graph.GetInputs();
|
||||
auto outputs = graph.GetOutputs();
|
||||
|
||||
|
|
@ -516,10 +506,11 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
|
||||
ASSERT_TRUE(outputs.size() == 10);
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
EXPECT_TRUE(split_outputs[i]->Name() == outputs[i]->Name());
|
||||
EXPECT_TRUE(graph_outputs[i]->Name() == outputs[i]->Name());
|
||||
}
|
||||
};
|
||||
|
||||
graph.SetInputs({&input_arg_a, &input_arg_b});
|
||||
graph.SetOutputs(graph_outputs);
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
|
|
|
|||
|
|
@ -268,8 +268,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
|
|||
}
|
||||
}
|
||||
|
||||
graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
|
||||
graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
|
||||
graph.SetInputs({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
|
||||
graph.SetOutputs({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
|
||||
|
||||
// optional input backed by an initializer to make sure that's handled too.
|
||||
// we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided
|
||||
|
|
@ -447,8 +447,8 @@ TEST(Loop, InfiniteLoopTermination) {
|
|||
graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs);
|
||||
}
|
||||
|
||||
graph.SetInputOrder({&iter_num_in, &cond_in, &outer_scope_0});
|
||||
graph.SetOutputOrder({&cond_out, &loop_var_0_out});
|
||||
graph.SetInputs({&iter_num_in, &cond_in, &outer_scope_0});
|
||||
graph.SetOutputs({&cond_out, &loop_var_0_out});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
|
|
|||
|
|
@ -1046,8 +1046,8 @@ void MixedTypeInputs(bool is_v8) {
|
|||
graph.AddNode("node3", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1});
|
||||
graph.AddNode("node4", "Identity", "Copy scan_in_2 to state_out_2", {&scan_in_2}, {&state_out_2});
|
||||
|
||||
graph.SetInputOrder({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2});
|
||||
graph.SetOutputOrder({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2});
|
||||
graph.SetInputs({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2});
|
||||
graph.SetOutputs({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
|
@ -1108,8 +1108,8 @@ void UnknownDimInSubgraphOutput(bool is_v8) {
|
|||
graph.AddNode("node1", "Identity", "Copy state_in_1 to scan_out_1", {&state_in_1}, {&scan_out_1});
|
||||
graph.AddNode("node2", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1});
|
||||
|
||||
graph.SetInputOrder({&state_in_1, &scan_in_1});
|
||||
graph.SetOutputOrder({&state_out_1, &scan_out_1});
|
||||
graph.SetInputs({&state_in_1, &scan_in_1});
|
||||
graph.SetOutputs({&state_out_1, &scan_out_1});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
|
|
|||
Loading…
Reference in a new issue