allow users to set graph inputs and outputs fully. (#905)

* allow users to set graph inputs and outputs fully.

* update

* update the comments of the APIs

* update

* remove commented-out codes.

* fix test failures.

* fix comments.

* adding more check to throw not support exception right now.
This commit is contained in:
Ke Zhang 2019-04-29 15:58:39 +08:00 committed by GitHub
parent bb58806872
commit f39a8d1f59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 90 additions and 102 deletions

View file

@ -727,27 +727,15 @@ class Graph {
ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name));
}
/** When programmatically constructing a Graph, explicitly set the order to use for graph inputs when the graph is
resolved.
This will determine the graph input order when the Graph is converted to a GraphProto by Graph::ToGraphProto.
@param inputs NodeArgs that represent graph inputs which need to be explicitly ordered.
Any graph inputs not in this list will be appended to the ordered graph input list, in the order that they were first
used by Nodes (i.e. the order of Node creation implicitly determines the ordering).
/** When programmatically constructing a Graph, explicitly set graph inputs.
@param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered.
@remarks If the Graph was loaded from a GraphProto this has no effect.*/
void SetInputOrder(const std::vector<const NodeArg*> inputs) {
graph_input_order_ = inputs;
}
void SetInputs(const std::vector<const NodeArg*> inputs);
/** When programmatically constructing a Graph, explicitly set the order to use for graph outputs when the graph is
resolved.
This will determine the graph output order when the Graph is converted to a GraphProto by Graph::ToGraphProto.
@param outputs NodeArgs that represent graph outputs which need to be explicitly ordered.
Any graph outputs not in this list will be appended to the ordered graph output list, in the order that they were first
produced by Nodes (i.e. the order of Node creation implicitly determines the ordering).
/** When programmatically constructing a Graph, explicitly set graph outputs.
@param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered.
@remarks If the Graph was loaded from a GraphProto this has no effect.*/
void SetOutputOrder(const std::vector<const NodeArg*> outputs) {
graph_output_order_ = outputs;
}
void SetOutputs(const std::vector<const NodeArg*> outputs);
/** Returns true if this is a subgraph or fase if it is a high-level graph. */
bool IsSubgraph() const { return parent_graph_ != nullptr; }
@ -945,12 +933,14 @@ class Graph {
// Full list of graph inputs. Matches number and order of inputs in the GraphProto.
std::vector<const NodeArg*> graph_inputs_including_initializers_;
bool graph_inputs_manually_set_ = false;
// Graph inputs excluding initializers.
std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
// Graph outputs.
std::vector<const NodeArg*> graph_outputs_;
bool graph_outputs_manually_set_ = false;
// Graph value_info.
std::vector<const NodeArg*> value_info_;
@ -975,12 +965,6 @@ class Graph {
// NodeArgs that come from outer scope. Used when building a graph so that
// these don't get recorded as graph inputs in the GraphProto.
std::unordered_set<std::string> outer_scope_node_arg_names_;
// Explicit graph input order to be used when constructing a Graph manually.
std::vector<const NodeArg*> graph_input_order_;
// Explicit graph output order to be used when constructing a Graph manually.
std::vector<const NodeArg*> graph_output_order_;
};
} // namespace onnxruntime

View file

@ -940,11 +940,8 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_c
// and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
// If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
if (!loaded_from_model_file) {
auto in_ordered_graph_outputs = find(graph_output_order_.cbegin(), graph_output_order_.cend(), node_arg);
if (in_ordered_graph_outputs == graph_output_order_.cend()) {
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
graph_outputs_.end());
}
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
graph_outputs_.end());
}
}
}
@ -2219,10 +2216,8 @@ void Graph::CleanUnusedInitializers() {
GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...)
Status Graph::SetGraphInputsOutputs() {
// Reset graph inputs/outputs/value info state.
// Reset graph inputs excluding initializers/value_info.
graph_inputs_excluding_initializers_.clear();
graph_inputs_including_initializers_.clear();
graph_outputs_.clear();
value_info_.clear();
// Flag indicates that this graph is loaded from model file.
@ -2231,10 +2226,11 @@ Status Graph::SetGraphInputsOutputs() {
// and outputs will be inferred.
const bool loaded_from_model_file = GraphLoadedFromModelFile(graph_proto_);
// if something is coming from outer scope, consider it already added
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
if (loaded_from_model_file) {
// Reset graph inputs/outputs.
graph_inputs_including_initializers_.clear();
graph_outputs_.clear();
// Name to NodeArg mapping of all graph initializers.
std::unordered_map<std::string, const NodeArg*> graph_initializers;
@ -2302,49 +2298,31 @@ Status Graph::SetGraphInputsOutputs() {
}
} else {
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
std::vector<std::string> ordered_output_names;
std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
std::vector<const NodeArg*> output_node_args_in_order;
// add any explicitly ordered inputs
for (auto* node_arg : graph_input_order_) {
if (!node_arg || !node_arg->Exists()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered inputs");
}
added_input_names.insert(node_arg->Name());
graph_inputs_including_initializers_.push_back(node_arg);
if (name_to_initial_tensor_.find(node_arg->Name()) == name_to_initial_tensor_.end()) {
graph_inputs_excluding_initializers_.push_back(node_arg);
}
// if something is coming from outer scope, consider it already added
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
if (!graph_inputs_manually_set_) {
graph_inputs_including_initializers_.clear();
}
// add any explicitly ordered outputs
for (auto* node_arg : graph_output_order_) {
if (!node_arg || !node_arg->Exists()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered outputs");
}
output_name_to_node_arg.insert({node_arg->Name(), node_arg});
ordered_output_names.push_back(node_arg->Name());
if (!graph_outputs_manually_set_) {
graph_outputs_.clear();
}
// add all other outputs
// Collect all nodes' outputs
for (const auto& node : Nodes()) {
for (const auto* output_def : node.OutputDefs()) {
if (output_def->Exists()) {
auto& name = output_def->Name();
// check it wasn't in the explicitly ordered outputs
if (output_name_to_node_arg.find(name) == output_name_to_node_arg.cend()) {
output_name_to_node_arg.insert({name, output_def});
ordered_output_names.push_back(name);
}
output_node_args_in_order.push_back(output_def);
output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
}
}
}
// Init graph output args with copy of all node output args.
auto graph_output_args = output_name_to_node_arg;
std::unordered_set<Node*> inner_nodes;
auto graph_output_args = output_name_to_node_arg_index;
for (const auto& node : Nodes()) {
// Go thru all node's inputs.
for (const auto* input_arg : node.InputDefs()) {
@ -2353,15 +2331,28 @@ Status Graph::SetGraphInputsOutputs() {
continue;
}
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
if (output_name_to_node_arg.end() == output_arg_iter) {
auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());
if (output_name_to_node_arg_index.end() == output_arg_iter) {
// This input arg should be fed when running evaluation.
// it should be a graph input.
const std::string& name = input_arg->Name();
if (added_input_names.end() == added_input_names.find(name)) {
// This graph input has not been added into <graph_inputs_>.
graph_inputs_including_initializers_.push_back(input_arg);
if (!graph_inputs_manually_set_) {
graph_inputs_including_initializers_.push_back(input_arg);
} else {
// Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
auto& inputs = GetInputsIncludingInitializers();
auto iter = std::find(inputs.begin(), inputs.end(), input_arg);
if (inputs.end() == iter) {
// it's not in graph inputs.
auto initializers = GetAllInitializedTensors();
if (initializers.end() == initializers.find(input_arg->Name())) {
// It's not in graph initializers.
return Status(ONNXRUNTIME, FAIL, input_arg->Name() + " must be either specified in graph inputs or graph initailizers.");
}
}
}
if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) {
graph_inputs_excluding_initializers_.push_back(input_arg);
}
@ -2377,12 +2368,15 @@ Status Graph::SetGraphInputsOutputs() {
}
}
// Set graph outputs
auto end = graph_output_args.end();
for (auto& name : ordered_output_names) {
auto graph_output = graph_output_args.find(name);
if (graph_output != end) {
graph_outputs_.push_back(graph_output->second);
if (!graph_outputs_manually_set_) {
// Set graph outputs in order.
std::vector<size_t> graph_output_args_index;
for (auto output_arg : graph_output_args) {
graph_output_args_index.push_back(output_arg.second);
}
std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
for (auto& output_arg_index : graph_output_args_index) {
graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
}
}
}
@ -2483,6 +2477,25 @@ Status Graph::InlineFunction(Node& node) {
return Status::OK();
}
void Graph::SetInputs(const std::vector<const NodeArg*> inputs) {
if (GraphLoadedFromModelFile(graph_proto_)) {
// TODO: add this support.
ORT_THROW("This API is not supported when model is loaded from proto file right now.");
}
graph_inputs_including_initializers_ = inputs;
graph_inputs_manually_set_ = true;
}
void Graph::SetOutputs(const std::vector<const NodeArg*> outputs) {
if (GraphLoadedFromModelFile(graph_proto_)) {
// TODO: add this support.
ORT_THROW("This API is not supported when model is loaded from proto file right now.");
}
graph_outputs_ = outputs;
graph_outputs_manually_set_ = true;
}
void Graph::AddFunction(const ONNX_NAMESPACE::FunctionProto* func_proto) {
this->model_functions_[func_proto->name()] = func_proto;
}

View file

@ -443,12 +443,6 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
for (auto i = 0; i < 20; ++i) {
map.insert({std::to_string(i), i});
std::cout << "Insert " << i << "\n";
for (auto pair : map) {
std::cout << pair.first << ":" << pair.second << " ";
}
std::cout << "\n";
}
// | |
@ -458,10 +452,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
// |
// d (Split)
// / \
// 1 .. 10
std::unordered_map<std::string, std::pair<std::vector<NodeArg*>, std::vector<NodeArg*>>>
expected_node_name_to_input_output_args;
// 1 .. 10
TypeProto tensor_int32;
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
@ -475,37 +466,36 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
auto& output_arg_c = graph.GetOrCreateNodeArg("node_c_out_1", &tensor_int32);
std::vector<NodeArg*> split_outputs;
std::vector<const NodeArg*> graph_outputs;
for (int i = 0; i < 10; ++i) {
split_outputs.push_back(&graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32));
auto arg = &graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32);
split_outputs.push_back(arg);
graph_outputs.push_back(arg);
}
std::reverse(graph_outputs.begin(), graph_outputs.end());
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
inputs.push_back(&input_arg_a);
outputs.push_back(&output_arg_a);
expected_node_name_to_input_output_args["a"] = {inputs, outputs};
graph.AddNode("a", "Identity_Fake", "a", inputs, outputs);
inputs.resize(2);
inputs[0] = &output_arg_b;
inputs[1] = &output_arg_a;
outputs[0] = &output_arg_c;
expected_node_name_to_input_output_args["c"] = {inputs, outputs};
graph.AddNode("c", "Merge_Fake", "c", inputs, outputs);
// deliberately add 'b' after 'c' to mix up the inputs as well
inputs.resize(1);
inputs[0] = &input_arg_b;
outputs[0] = &output_arg_b;
expected_node_name_to_input_output_args["b"] = {inputs, outputs};
graph.AddNode("b", "Identity_Fake", "b", inputs, outputs);
inputs[0] = &output_arg_c;
expected_node_name_to_input_output_args["d"] = {inputs, split_outputs};
graph.AddNode("d", "Split_Fake", "d", inputs, split_outputs);
auto validate_inputs_outputs = [&split_outputs](const Graph& graph) {
auto validate_inputs_outputs = [&graph_outputs](const Graph& graph) {
auto inputs = graph.GetInputs();
auto outputs = graph.GetOutputs();
@ -516,10 +506,11 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
ASSERT_TRUE(outputs.size() == 10);
for (int i = 0; i < 10; ++i) {
EXPECT_TRUE(split_outputs[i]->Name() == outputs[i]->Name());
EXPECT_TRUE(graph_outputs[i]->Name() == outputs[i]->Name());
}
};
graph.SetInputs({&input_arg_a, &input_arg_b});
graph.SetOutputs(graph_outputs);
auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();

View file

@ -268,8 +268,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
}
}
graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
graph.SetInputs({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
graph.SetOutputs({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
// optional input backed by an initializer to make sure that's handled too.
// we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided
@ -447,8 +447,8 @@ TEST(Loop, InfiniteLoopTermination) {
graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs);
}
graph.SetInputOrder({&iter_num_in, &cond_in, &outer_scope_0});
graph.SetOutputOrder({&cond_out, &loop_var_0_out});
graph.SetInputs({&iter_num_in, &cond_in, &outer_scope_0});
graph.SetOutputs({&cond_out, &loop_var_0_out});
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());

View file

@ -1046,8 +1046,8 @@ void MixedTypeInputs(bool is_v8) {
graph.AddNode("node3", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1});
graph.AddNode("node4", "Identity", "Copy scan_in_2 to state_out_2", {&scan_in_2}, {&state_out_2});
graph.SetInputOrder({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2});
graph.SetOutputOrder({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2});
graph.SetInputs({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2});
graph.SetOutputs({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2});
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());
@ -1108,8 +1108,8 @@ void UnknownDimInSubgraphOutput(bool is_v8) {
graph.AddNode("node1", "Identity", "Copy state_in_1 to scan_out_1", {&state_in_1}, {&scan_out_1});
graph.AddNode("node2", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1});
graph.SetInputOrder({&state_in_1, &scan_in_1});
graph.SetOutputOrder({&state_out_1, &scan_out_1});
graph.SetInputs({&state_in_1, &scan_in_1});
graph.SetOutputs({&state_out_1, &scan_out_1});
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());