Avoid call to Node::ToProto on first Graph::Resolve to improve session creation performance. (#20296)

### Description
<!-- Describe your changes. -->
The first call to Graph::Resolve occurs while the Graph instance is being
created from an existing ModelProto. At that point each Node instance
exactly matches its source NodeProto, so there's no need to call
Node::ToProto in this case.

Add a temporary reference to the original NodeProto to avoid the call on
the first Graph::Resolve.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Better alternative to #19469
This commit is contained in:
Scott McKay 2024-04-17 10:07:12 +10:00 committed by GitHub
parent c11941289b
commit 5c8034cc20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 50 additions and 17 deletions

View file

@ -621,6 +621,22 @@ class Node {
// Reference to the function template defined in the model.
const FunctionTemplate* func_template_ = nullptr;
// Set or clear the NodeProto that the Node was created from.
// Set by Graph ctor when loading a model from file.
// Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs.
// Only the pointer is stored; the NodeProto itself is not copied. Pass nullptr to clear.
// NOTE(review): presumably the caller must keep the NodeProto alive until the pointer is cleared - confirm.
void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) {
original_node_proto_ = node_proto;
}
// Returns the NodeProto pointer previously stored via SetOriginalNodeProto, or nullptr if none is
// currently set (the member defaults to nullptr and is reset to nullptr when cleared).
const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const {
return original_node_proto_;
}
// NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling
// Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are
// unchanged from the original model.
const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr;
#endif
// Execution priority, lower value for higher priority

View file

@ -2583,9 +2583,17 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
{
auto status = Status::OK();
ORT_TRY {
NodeProto node_proto;
node.ToProto(node_proto);
checker::check_node(node_proto, ctx, lsc);
// if this is the first Graph::Resolve call, we may have a NodeProto that was set on the Node so we can skip
// the ToProto call.
if (const NodeProto* orig_node_proto = node.GetOriginalNodeProto(); orig_node_proto) {
checker::check_node(*orig_node_proto, ctx, lsc);
// clear original as we don't know if the node will be modified once the Graph::Resolve completes.
node.SetOriginalNodeProto(nullptr);
} else {
NodeProto node_proto;
node.ToProto(node_proto);
checker::check_node(node_proto, ctx, lsc);
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
@ -3123,13 +3131,25 @@ Node& Graph::AddNode(const NodeProto& node_proto,
attributes[attr.name()] = attr;
}
return AddNode(node_proto.name(),
node_proto.op_type(),
node_proto.doc_string(),
input_defs,
output_defs,
&attributes,
node_proto.domain());
Node& new_node = AddNode(node_proto.name(),
node_proto.op_type(),
node_proto.doc_string(),
input_defs,
output_defs,
&attributes,
node_proto.domain());
// Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to
// calling onnx::check_node
// NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain,
// and that's what onnx::check_node uses during validation.
// The Node ctor automatically converts kOnnxDomainAlias to kOnnxDomain to handle this.
// node_proto is const so we can't do the same here.
if (node_proto.domain() != kOnnxDomainAlias) {
new_node.SetOriginalNodeProto(&node_proto);
}
return new_node;
}
static flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>

View file

@ -1956,12 +1956,9 @@ TEST_F(PlannerTest, TestCpuIf) {
sess_opt.graph_optimization_level = TransformerLevel::Default;
InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/cpu_if.onnx"));
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
ASSERT_TRUE(status.IsOK());
status = sess.Load();
ASSERT_TRUE(status.IsOK());
status = sess.Initialize();
ASSERT_TRUE(status.IsOK());
ASSERT_STATUS_OK(sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
ASSERT_STATUS_OK(sess.Load());
ASSERT_STATUS_OK(sess.Initialize());
auto& sess_state = const_cast<onnxruntime::SessionState&>(sess.GetSessionState());
const auto& exe_plan = sess_state.GetExecutionPlan()->execution_plan;
@ -1971,7 +1968,7 @@ TEST_F(PlannerTest, TestCpuIf) {
exe_plan[1]->steps_[7]->GetNodeIndex() == 7) {
// there must be a wait before cpu If node
static const std::string WaitOnEPStep = "WaitOnEPStep";
ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
ASSERT_EQ(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()), WaitOnEPStep);
}
}