Update the GraphProto for subgraphs when saving the Graph. (#647)

* Update the GraphProto for subgraphs when saving the Graph. This is required to produce a valid overall Graph if the Graph has been optimized.
This commit is contained in:
Scott McKay 2019-10-23 15:14:06 -07:00 committed by GitHub
parent 6fca8b0a94
commit 41d55ea274
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 11 deletions

View file

@ -298,8 +298,12 @@ class Node {
/** Sets the execution ProviderType that this Node will be executed by. */
void SetExecutionProviderType(ProviderType execution_provider_type);
/** Gets the NodeProto representation of this Node. */
void ToProto(ONNX_NAMESPACE::NodeProto& proto) const;
/** Gets the NodeProto representation of this Node.
@param update_subgraphs Update the GraphProto values for any subgraphs in the returned NodeProto.
If graph optimization has been run this is most likely required
to ensure the complete Graph is valid.
*/
void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const;
/** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
If the NodeArg is an explicit or implicit input, is_input will be true when func is called.

View file

@ -399,21 +399,25 @@ void Node::SetExecutionProviderType(ProviderType execution_provider_type) {
execution_provider_type_ = execution_provider_type;
}
void Node::ToProto(NodeProto& proto) const {
// Set name.
void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
proto.set_name(name_);
// Set op type.
proto.set_op_type(op_type_);
// Set op domain;
proto.set_domain(domain_);
// Set doc string.
proto.set_doc_string(description_);
if (!domain_.empty())
proto.set_domain(domain_);
if (!description_.empty())
proto.set_doc_string(description_);
// Set attributes.
proto.clear_attribute();
for (const auto& attribute : attributes_) {
const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
*attr = attribute.second;
*attr = attribute.second; // copy
if (update_subgraphs && attr->has_g()) {
attr->clear_g();
*attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto();
}
}
// Set inputs' definitions.
@ -2327,7 +2331,9 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const
for (auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
const gsl::not_null<NodeProto*> node_proto{graph_proto.add_node()};
const gsl::not_null<const Node*> p_node{GetNode(node_idx)};
p_node->ToProto(*node_proto);
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
// such as the optimizers are captured. otherwise we can end up saving an invalid graph.
p_node->ToProto(*node_proto, /* update_subgraphs */ true);
}
}