mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Avoid call to Node::ToProto on first Graph::Resolve to improve session creation performance. (#20296)
### Description
<!-- Describe your changes. -->
The first call to Graph::Resolve occurs when the Graph instance is created while loading an existing model from a ModelProto. Because each Node instance exactly matches its source NodeProto at that point, there is no need to call Node::ToProto. Add a temporary reference to the original NodeProto to avoid that call on the first Graph::Resolve.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
Better alternative to #19469
This commit is contained in:
parent
c11941289b
commit
5c8034cc20
4 changed files with 50 additions and 17 deletions
|
|
@ -621,6 +621,22 @@ class Node {
|
|||
|
||||
// Reference to the function template defined in the model.
|
||||
const FunctionTemplate* func_template_ = nullptr;
|
||||
|
||||
// set/clear NodeProto that the Node was created from.
|
||||
// Set by Graph ctor when loading a model from file.
|
||||
// Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs.
|
||||
// Records `node_proto` as the NodeProto this Node was created from.
// Only the pointer is stored (the NodeProto itself is not copied); pass nullptr to clear it.
void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) {
  this->original_node_proto_ = node_proto;
}
|
||||
|
||||
// Returns the pointer currently held in original_node_proto_.
// This is nullptr when no original NodeProto is set (the member's default) or after it has been cleared.
const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const {
  return this->original_node_proto_;
}
|
||||
|
||||
// NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling
|
||||
// Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are
|
||||
// unchanged from the original model.
|
||||
const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr;
|
||||
#endif
|
||||
|
||||
// Execution priority, lower value for higher priority
|
||||
|
|
|
|||
|
|
@ -2583,9 +2583,17 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
|
|||
{
|
||||
auto status = Status::OK();
|
||||
ORT_TRY {
|
||||
NodeProto node_proto;
|
||||
node.ToProto(node_proto);
|
||||
checker::check_node(node_proto, ctx, lsc);
|
||||
// if this is first Graph::Resolve call, we may have a NodeProto that was set on the Node so we can skip
|
||||
// the ToProto call.
|
||||
if (const NodeProto* orig_node_proto = node.GetOriginalNodeProto(); orig_node_proto) {
|
||||
checker::check_node(*orig_node_proto, ctx, lsc);
|
||||
// clear original as we don't know if the node will be modified once the Graph::Resolve completes.
|
||||
node.SetOriginalNodeProto(nullptr);
|
||||
} else {
|
||||
NodeProto node_proto;
|
||||
node.ToProto(node_proto);
|
||||
checker::check_node(node_proto, ctx, lsc);
|
||||
}
|
||||
}
|
||||
ORT_CATCH(const std::exception& ex) {
|
||||
ORT_HANDLE_EXCEPTION([&]() {
|
||||
|
|
@ -3123,13 +3131,25 @@ Node& Graph::AddNode(const NodeProto& node_proto,
|
|||
attributes[attr.name()] = attr;
|
||||
}
|
||||
|
||||
return AddNode(node_proto.name(),
|
||||
node_proto.op_type(),
|
||||
node_proto.doc_string(),
|
||||
input_defs,
|
||||
output_defs,
|
||||
&attributes,
|
||||
node_proto.domain());
|
||||
Node& new_node = AddNode(node_proto.name(),
|
||||
node_proto.op_type(),
|
||||
node_proto.doc_string(),
|
||||
input_defs,
|
||||
output_defs,
|
||||
&attributes,
|
||||
node_proto.domain());
|
||||
|
||||
// Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to
|
||||
// calling onnx::check_node
|
||||
// NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain,
|
||||
// and that's what onnx::check_node uses during validation.
|
||||
// The Node ctor automatically converts kOnnxDomainAlias to kOnnxDomain to handle this.
|
||||
// node_proto is const so we can't do the same here.
|
||||
if (node_proto.domain() != kOnnxDomainAlias) {
|
||||
new_node.SetOriginalNodeProto(&node_proto);
|
||||
}
|
||||
|
||||
return new_node;
|
||||
}
|
||||
|
||||
static flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
|
||||
|
|
|
|||
|
|
@ -1956,12 +1956,9 @@ TEST_F(PlannerTest, TestCpuIf) {
|
|||
sess_opt.graph_optimization_level = TransformerLevel::Default;
|
||||
|
||||
InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/cpu_if.onnx"));
|
||||
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = sess.Load();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = sess.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
ASSERT_STATUS_OK(sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
|
||||
ASSERT_STATUS_OK(sess.Load());
|
||||
ASSERT_STATUS_OK(sess.Initialize());
|
||||
|
||||
auto& sess_state = const_cast<onnxruntime::SessionState&>(sess.GetSessionState());
|
||||
const auto& exe_plan = sess_state.GetExecutionPlan()->execution_plan;
|
||||
|
|
@ -1971,7 +1968,7 @@ TEST_F(PlannerTest, TestCpuIf) {
|
|||
exe_plan[1]->steps_[7]->GetNodeIndex() == 7) {
|
||||
// there must be a wait before cpu If node
|
||||
static const std::string WaitOnEPStep = "WaitOnEPStep";
|
||||
ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
|
||||
ASSERT_EQ(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()), WaitOnEPStep);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Binary file not shown.
Loading…
Reference in a new issue