diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index c4c1e71356..cafbf007d1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -298,8 +298,12 @@ class Node { /** Sets the execution ProviderType that this Node will be executed by. */ void SetExecutionProviderType(ProviderType execution_provider_type); - /** Gets the NodeProto representation of this Node. */ - void ToProto(ONNX_NAMESPACE::NodeProto& proto) const; + /** Gets the NodeProto representation of this Node. + @param update_subgraphs Update the GraphProto values for any subgraphs in the returned NodeProto. + If graph optimization has been run this is most likely required + to ensure the complete Graph is valid. + */ + void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const; /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node. If the NodeArg is an explicit or implicit input, is_input will be true when func is called. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e70faff417..d06b1eccb1 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -399,21 +399,25 @@ void Node::SetExecutionProviderType(ProviderType execution_provider_type) { execution_provider_type_ = execution_provider_type; } -void Node::ToProto(NodeProto& proto) const { - // Set name. +void Node::ToProto(NodeProto& proto, bool update_subgraphs) const { proto.set_name(name_); - // Set op type. proto.set_op_type(op_type_); - // Set op domain; - proto.set_domain(domain_); - // Set doc string. - proto.set_doc_string(description_); + + if (!domain_.empty()) + proto.set_domain(domain_); + + if (!description_.empty()) + proto.set_doc_string(description_); // Set attributes. proto.clear_attribute(); for (const auto& attribute : attributes_) { const gsl::not_null attr{proto.add_attribute()}; - *attr = attribute.second; + *attr = attribute.second; // copy + if (update_subgraphs && attr->has_g()) { + attr->clear_g(); + *attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto(); + } } // Set inputs' definitions. @@ -2327,7 +2331,9 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const for (auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { const gsl::not_null node_proto{graph_proto.add_node()}; const gsl::not_null p_node{GetNode(node_idx)}; - p_node->ToProto(*node_proto); + // we need to update any GraphProto attributes for subgraphs so that any changes made by things + // such as the optimizers are captured. otherwise we can end up saving an invalid graph. + p_node->ToProto(*node_proto, /* update_subgraphs */ true); } }