mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Make execution order an option for GraphViewerToProto() (#20411)
**Current issue:** Once ORT gets the capability from the EP's GetCapability(), it creates a graph viewer based on that capability, as shown below: `viewers.push_back(std::make_unique<GraphViewer>(graph, *cur_capability.sub_graph));` or see the code [here](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/graph_partitioner.cc#L458). At this point, the graph viewer may generate `nodes_in_topological_order_` in the wrong order when calling [Graph::ReverseDFSFrom](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_viewer.cc#L107), so that during EP Compile(), the EP might create a model proto with the "wrong node ordering" from the graph viewer when calling [GraphViewerToProto()](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_proto_serializer.cc#L37), because that function relies on `nodes_in_topological_order_`. This is a problem when the TRT EP refits weights to the "weightless" engine: the engine is built from the model proto provided by the TRT EP, while the weights are in the original ONNX model. Because the model proto and the original ONNX model differ in node ordering, TRT complains when refitting. **The original model (subgraph of ResNet50):** <img width="442" alt="image" src="https://github.com/microsoft/onnxruntime/assets/54722500/bb9a641d-f2f2-46c3-aebf-4084a08ff289"> **The serialized model proto generated by TRT EP:** (The highlighted part has the wrong node order compared to the original model.)
<img width="340" alt="image" src="https://github.com/microsoft/onnxruntime/assets/54722500/bbc6bf34-f960-4753-9474-a18ebc2dc48b"> **Solution 1:** Change the default comparator to `NodeCompare::operator() {return n1->Index() > n2->Index();}` The root cause of the different node order between the original model and the EP-generated model is the graph viewer [generating](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_viewer.cc#L107) a different `nodes_in_topological_order_`. Modifying the `NodeCompare::operator()` used for sorting can fix the problem. `NodeCompare::operator()` is used in [Graph::ReverseDFSFrom](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L1760), where the input nodes of the current node are [sorted](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L1802) based on node index. Because the sorted nodes are pushed onto a stack, which later determines the final topological node order in a "first in, last out" manner, the node with the larger index should be pushed onto the stack first, so that the resulting topological order places nodes with smaller indices first. **Solution 2 (this PR uses this solution):** Use priority-based BFS for the topological sort in GraphViewerToProto().
This commit is contained in:
parent
21b3cbc3af
commit
bbc30feb63
6 changed files with 38 additions and 10 deletions
|
|
@ -8,7 +8,8 @@ namespace onnxruntime {
|
|||
void GraphViewerToProto(const GraphViewer& graph_view,
|
||||
ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
bool include_initializer,
|
||||
bool include_outer_scope_args) {
|
||||
bool include_outer_scope_args,
|
||||
ExecutionOrder order = ExecutionOrder::DEFAULT) {
|
||||
graph_proto.set_name(graph_view.Name());
|
||||
graph_proto.set_doc_string(graph_view.Description());
|
||||
|
||||
|
|
@ -34,7 +35,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,
|
|||
}
|
||||
|
||||
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
|
||||
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) {
|
||||
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) {
|
||||
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
|
||||
const gsl::not_null<const Node*> p_node{graph_view.GetNode(node_idx)};
|
||||
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
|
||||
|
|
@ -62,7 +63,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,
|
|||
|
||||
// handle outer scope value which is a constant initializer
|
||||
if (include_outer_scope_args) {
|
||||
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) {
|
||||
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) {
|
||||
const auto& node = graph_view.GetNode(node_idx);
|
||||
for (const auto& input : node->InputDefs()) {
|
||||
if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) {
|
||||
|
|
|
|||
|
|
@ -7,5 +7,9 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args);
|
||||
void GraphViewerToProto(const GraphViewer& graph_view,
|
||||
ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
bool include_initializer,
|
||||
bool include_outer_scope_args,
|
||||
ExecutionOrder order);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -851,7 +851,11 @@ struct ProviderHost {
|
|||
virtual const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0;
|
||||
virtual const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0;
|
||||
|
||||
virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0;
|
||||
virtual void GraphViewer__ToProto(const GraphViewer* p,
|
||||
ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
bool include_initializers,
|
||||
bool include_outer_scope_args,
|
||||
int execution_order) noexcept = 0;
|
||||
virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0;
|
||||
|
||||
// Path
|
||||
|
|
|
|||
|
|
@ -887,7 +887,12 @@ class GraphViewer final {
|
|||
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); }
|
||||
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); }
|
||||
|
||||
void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); }
|
||||
void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
bool include_initializers,
|
||||
bool include_outer_scope_args,
|
||||
int execution_order = 0) const {
|
||||
g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order);
|
||||
}
|
||||
const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); }
|
||||
|
||||
GraphViewer() = delete;
|
||||
|
|
|
|||
|
|
@ -2103,7 +2103,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
auto graph_viewer = graph_build.CreateGraphViewer();
|
||||
auto model = graph_viewer->CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
|
||||
|
||||
// ORT's default topological sort is using reversed DFS.
|
||||
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
|
||||
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
|
||||
// the model proto that has different node ordering compared to original onnx model.
|
||||
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
std::string string_buf;
|
||||
|
|
@ -2499,7 +2504,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
|
|||
// Reconstruct graph proto from fused node's function body
|
||||
auto model = graph_body_viewer.CreateModel(*GetLogger());
|
||||
auto model_proto = model->ToProto();
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
|
||||
|
||||
// ORT's default topological sort is using reversed DFS.
|
||||
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
|
||||
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
|
||||
// the model proto that has different node ordering compared to original onnx model.
|
||||
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/);
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
std::string string_buf;
|
||||
model_proto->SerializeToString(string_buf);
|
||||
|
|
|
|||
|
|
@ -1088,8 +1088,12 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); }
|
||||
const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); }
|
||||
void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override {
|
||||
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args);
|
||||
void GraphViewer__ToProto(const GraphViewer* p,
|
||||
ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
bool include_initializers,
|
||||
bool include_outer_scope_args,
|
||||
int execution_order) noexcept override {
|
||||
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast<ExecutionOrder>(execution_order));
|
||||
}
|
||||
const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue