mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Fix handling of nodes inserted by NHWC transformer. (#10904)
This commit is contained in:
parent
e03b799b95
commit
07a71d5bf2
2 changed files with 7 additions and 9 deletions
|
|
@ -988,23 +988,18 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
|
|||
// kernel hashes for model are in top level SessionState
|
||||
const auto& compiled_kernel_hashes = GetCompiledKernelHashes();
|
||||
|
||||
const bool original_nodes_should_exist =
|
||||
compiled_kernel_hashes.empty()
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
&& graph_.RuntimeOptimizationReplayCtx().num_replayed_optimizations == 0
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
;
|
||||
|
||||
// process the nodes that existed when the model was created
|
||||
for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) {
|
||||
const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i);
|
||||
|
||||
Node* const node = graph_.GetNode(node_kernel_info.node_index);
|
||||
if (node == nullptr) {
|
||||
// this is OK if we have compiled kernels/replayed runtime optimizations and the original node was replaced.
|
||||
#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
// this is OK if we have compiled kernels and the original node was replaced.
|
||||
// if not the model is invalid.
|
||||
ORT_RETURN_IF(original_nodes_should_exist,
|
||||
ORT_RETURN_IF(compiled_kernel_hashes.empty(),
|
||||
"Can't find node with index ", node_kernel_info.node_index, ". Invalid ORT format model.");
|
||||
#endif // defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1280,6 +1280,9 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs
|
|||
for (const auto& node : graph.Nodes()) {
|
||||
if (node.GetExecutionProviderType().empty()) {
|
||||
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
|
||||
if (!kernel_hash.has_value()) {
|
||||
kernel_hash = utils::GetInternalNhwcOpHash(node);
|
||||
}
|
||||
if (kernel_hash.has_value()) {
|
||||
ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue