From 07a71d5bf231fce2995be2075bdda9b3ed9bcfb9 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 17 Mar 2022 12:41:34 -0700 Subject: [PATCH] Fix handling of nodes inserted by NHWC transformer. (#10904) --- onnxruntime/core/framework/session_state.cc | 13 ++++--------- onnxruntime/core/session/inference_session.cc | 3 +++ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 139cbd6d22..7f1499e2e3 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -988,23 +988,18 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat // kernel hashes for model are in top level SessionState const auto& compiled_kernel_hashes = GetCompiledKernelHashes(); - const bool original_nodes_should_exist = - compiled_kernel_hashes.empty() -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - && graph_.RuntimeOptimizationReplayCtx().num_replayed_optimizations == 0 -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - ; - // process the nodes that existed when the model was created for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) { const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i); Node* const node = graph_.GetNode(node_kernel_info.node_index); if (node == nullptr) { - // this is OK if we have compiled kernels/replayed runtime optimizations and the original node was replaced. +#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) + // this is OK if we have compiled kernels and the original node was replaced. // if not the model is invalid. - ORT_RETURN_IF(original_nodes_should_exist, + ORT_RETURN_IF(compiled_kernel_hashes.empty(), "Can't find node with index ", node_kernel_info.node_index, ". Invalid ORT format model."); +#endif // defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) continue; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f665a44947..e0f3e6a373 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1280,6 +1280,9 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs for (const auto& node : graph.Nodes()) { if (node.GetExecutionProviderType().empty()) { auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion()); + if (!kernel_hash.has_value()) { + kernel_hash = utils::GetInternalNhwcOpHash(node); + } if (kernel_hash.has_value()) { ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value())); }