From fdb132643d8e07d383d6cd9af1c19f547fe1718f Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Sep 2023 12:13:37 -0700 Subject: [PATCH] Remove redundant Resolve() after each inlined function (#17556) ### Description Remove `Resolve()` on the entire graph as each function is resolved. We retain `Resolve()` after each inlining iteration. ### Motivation and Context Poor performance for inlining the model and session initialization. Original model before Resolve() removal FunctionTest.Profiling (**65953 ms**) After Resolve() Removal FunctionTest.Profiling (**2911 ms**) RelWithDebInfo pre-inlined model. Presumably because it runs Level1 optimizers Non-inlined model consists of functions and Level1 optimizers have no effect. FunctionTest.Profiling (**9851 ms**) --- include/onnxruntime/core/graph/graph.h | 1 + onnxruntime/core/graph/graph.cc | 2 -- .../orttraining/test/training_ops/function_op_test_utils.cc | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 19caa69d94..f153e88909 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1135,6 +1135,7 @@ class Graph { /** Directly insert the nodes in the function Node provided into this Graph. + The Graph needs to be Resolve()d after this call. @param node Node with Node::Type of Node::Type::Fused @returns Status indicating success or providing an error message. */ diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index d4164681f2..383c1d689d 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4145,8 +4145,6 @@ Status Graph::InlineFunction(Node& callnode) { // std::cout << "Graph after inlining\n\n" << *this << std::endl << std::flush; - ORT_RETURN_IF_ERROR(this->Resolve()); - return Status::OK(); } diff --git a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc index 5eed4765ab..9504ba2c1e 100644 --- a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc +++ b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc @@ -25,8 +25,8 @@ void OpFunctionTester::RunFunctionBodyGraphOnCPU(TwoDArray& results) { auto& node = *graph.Nodes().begin(); ASSERT_EQ(node.OpType(), op); - // Inline function will call Resolve itself ASSERT_STATUS_OK(graph.InlineFunction(node)); + ASSERT_STATUS_OK(graph.Resolve()); // Hookup the inputs and outputs std::unordered_map feeds;