Fix handling of nodes inserted by NHWC transformer. (#10904)

This commit is contained in:
Edward Chen 2022-03-17 12:41:34 -07:00 committed by GitHub
parent e03b799b95
commit 07a71d5bf2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 9 deletions

View file

@ -988,23 +988,18 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
// kernel hashes for model are in top level SessionState
const auto& compiled_kernel_hashes = GetCompiledKernelHashes();
const bool original_nodes_should_exist =
compiled_kernel_hashes.empty()
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
&& graph_.RuntimeOptimizationReplayCtx().num_replayed_optimizations == 0
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
;
// process the nodes that existed when the model was created
for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) {
const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i);
Node* const node = graph_.GetNode(node_kernel_info.node_index);
if (node == nullptr) {
// this is OK if we have compiled kernels/replayed runtime optimizations and the original node was replaced.
#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
// this is OK if we have compiled kernels and the original node was replaced.
// if not the model is invalid.
ORT_RETURN_IF(original_nodes_should_exist,
ORT_RETURN_IF(compiled_kernel_hashes.empty(),
"Can't find node with index ", node_kernel_info.node_index, ". Invalid ORT format model.");
#endif // defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
continue;
}

View file

@ -1280,6 +1280,9 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs
for (const auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType().empty()) {
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (!kernel_hash.has_value()) {
kernel_hash = utils::GetInternalNhwcOpHash(node);
}
if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
}