mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Update the GraphProto for subgraphs when saving the Graph. (#647)
* Update the GraphProto for subgraphs when saving the Graph. This is required to produce a valid overall Graph if the Graph has been optimized.
This commit is contained in:
parent
6fca8b0a94
commit
41d55ea274
2 changed files with 21 additions and 11 deletions
|
|
@ -298,8 +298,12 @@ class Node {
|
|||
/** Sets the execution ProviderType that this Node will be executed by. */
|
||||
void SetExecutionProviderType(ProviderType execution_provider_type);
|
||||
|
||||
/** Gets the NodeProto representation of this Node. */
|
||||
void ToProto(ONNX_NAMESPACE::NodeProto& proto) const;
|
||||
/** Gets the NodeProto representation of this Node.
|
||||
@param update_subgraphs Update the GraphProto values for any subgraphs in the returned NodeProto.
|
||||
If graph optimization has been run this is most likely required
|
||||
to ensure the complete Graph is valid.
|
||||
*/
|
||||
void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const;
|
||||
|
||||
/** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
|
||||
If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
|
||||
|
|
|
|||
|
|
@ -399,21 +399,25 @@ void Node::SetExecutionProviderType(ProviderType execution_provider_type) {
|
|||
execution_provider_type_ = execution_provider_type;
|
||||
}
|
||||
|
||||
void Node::ToProto(NodeProto& proto) const {
|
||||
// Set name.
|
||||
void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
|
||||
proto.set_name(name_);
|
||||
// Set op type.
|
||||
proto.set_op_type(op_type_);
|
||||
// Set op domain;
|
||||
proto.set_domain(domain_);
|
||||
// Set doc string.
|
||||
proto.set_doc_string(description_);
|
||||
|
||||
if (!domain_.empty())
|
||||
proto.set_domain(domain_);
|
||||
|
||||
if (!description_.empty())
|
||||
proto.set_doc_string(description_);
|
||||
|
||||
// Set attributes.
|
||||
proto.clear_attribute();
|
||||
for (const auto& attribute : attributes_) {
|
||||
const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
|
||||
*attr = attribute.second;
|
||||
*attr = attribute.second; // copy
|
||||
if (update_subgraphs && attr->has_g()) {
|
||||
attr->clear_g();
|
||||
*attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto();
|
||||
}
|
||||
}
|
||||
|
||||
// Set inputs' definitions.
|
||||
|
|
@ -2327,7 +2331,9 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const
|
|||
for (auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
const gsl::not_null<NodeProto*> node_proto{graph_proto.add_node()};
|
||||
const gsl::not_null<const Node*> p_node{GetNode(node_idx)};
|
||||
p_node->ToProto(*node_proto);
|
||||
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
|
||||
// such as the optimizers are captured. otherwise we can end up saving an invalid graph.
|
||||
p_node->ToProto(*node_proto, /* update_subgraphs */ true);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue