From dc50aa42d5b1cf591b8d63ec37bdf108c389da39 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 20 Aug 2020 12:19:38 +1000 Subject: [PATCH] Refactor session state finalization and kernel lookup usage (#4763) * Refactor SessionState to support coming de/serialization changes - move more parts into SessionState to simplify usage - do the kernel lookup once instead of multiple times from different places - rename finalize_session_state.* to session_state_utils.* as the finalization logic is now inside SessionState * Fix some build issues * Move subgraph session state creation into SessionState. It's not needed by GraphPartitioner any more so we can delay the creation until later. Fixes issue where EP may have removed the subgraph during partitioning when taking a control flow node, and SessionState thought the subgraph was still valid. * Address PR comments * Clarify a comment --- .../core/framework/allocation_planner.cc | 77 ++++--- .../core/framework/allocation_planner.h | 14 +- .../core/framework/finalize_session_state.h | 25 --- .../core/framework/kernel_registry_manager.cc | 88 +++----- .../core/framework/kernel_registry_manager.h | 10 +- onnxruntime/core/framework/session_state.cc | 212 +++++++++++++++--- onnxruntime/core/framework/session_state.h | 92 ++++---- ...ession_state.cc => session_state_utils.cc} | 117 ++-------- .../core/framework/session_state_utils.h | 42 ++++ onnxruntime/core/session/inference_session.cc | 120 ++-------- onnxruntime/core/session/inference_session.h | 22 +- .../test/framework/allocation_planner_test.cc | 52 +++-- .../test/framework/execution_frame_test.cc | 30 +-- .../test/framework/session_state_test.cc | 23 +- onnxruntime/test/providers/memcpy_test.cc | 5 +- 15 files changed, 459 insertions(+), 470 deletions(-) delete mode 100644 onnxruntime/core/framework/finalize_session_state.h rename onnxruntime/core/framework/{finalize_session_state.cc => session_state_utils.cc} (67%) create mode 100644 onnxruntime/core/framework/session_state_utils.h diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index f392725ebb..084c9a1ddc 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -100,11 +100,22 @@ std::ostream& operator<<(std::ostream& out, std::pair>& kernel_create_info_map, + NodeIndex node_index) { + auto entry = kernel_create_info_map.find(node_index); + ORT_ENFORCE(entry != kernel_create_info_map.cend(), + "SessionState should have saved the KernelCreateInfo prior to this running. NodeIndex:", node_index); + + return *entry->second; +} + class PlannerImpl { public: PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer, const std::vector& outer_scope_node_args, const ExecutionProviders& providers, - const KernelRegistryManager& kernel_registry, const OrtValueNameIdxMap& ort_value_name_idx_map, + const std::unordered_map>& kernel_create_info_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) : context_(context), plan_(plan), @@ -112,7 +123,7 @@ class PlannerImpl { graph_viewer_(graph_viewer), outer_scope_node_args_(outer_scope_node_args), execution_providers_(providers), - kernel_registry_(kernel_registry), + kernel_create_info_map_(kernel_create_info_map), ort_value_name_idx_map_(ort_value_name_idx_map) {} Status CreatePlan(); @@ -126,7 +137,7 @@ class PlannerImpl { const std::vector& outer_scope_node_args_; const ExecutionProviders& execution_providers_; - const KernelRegistryManager& kernel_registry_; + const std::unordered_map>& kernel_create_info_map_; const OrtValueNameIdxMap& ort_value_name_idx_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -206,13 +217,13 @@ class PlannerImpl { // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input) { auto p_output_arg = node.OutputDefs()[output_arg_num]; - const KernelCreateInfo* ci; - Status st = kernel_registry_.SearchKernelRegistry(node, &ci); - if (!st.IsOK() || ci == nullptr || ci->kernel_def == nullptr) { + const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index()); + + if (ci.kernel_def == nullptr) { return false; } - const std::vector>& alias_map = ci->kernel_def->Alias(); + const std::vector>& alias_map = ci.kernel_def->Alias(); auto input_args = node.InputDefs(); for (auto pair : alias_map) { if (pair.second == output_arg_num) { @@ -227,7 +238,7 @@ class PlannerImpl { } } - const std::vector>& inplace_map = ci->kernel_def->MayInplace(); + const std::vector>& inplace_map = ci.kernel_def->MayInplace(); for (auto pair : inplace_map) { if (pair.second == output_arg_num) { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { @@ -392,22 +403,17 @@ class PlannerImpl { for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) { auto pnode = graph_viewer_.GetNode(step.node_index); - if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index); + if (pnode == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index); + } // Identify where each output of this node should be allocated. - // This is determined by the opkernel bound to the node. - const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry_.SearchKernelRegistry(*pnode, &kernel_create_info)); - auto p_kernel_def = kernel_create_info->kernel_def.get(); - if (nullptr == p_kernel_def) { - std::ostringstream errormsg; - errormsg << "No suitable kernel definition found for op " << pnode->OpType(); - if (pnode->Op() != nullptr) { - errormsg << "(" << pnode->Op()->since_version() << ")"; - } - if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")"; - return Status(ONNXRUNTIME, FAIL, errormsg.str()); - } + // This is determined by the OpKernel bound to the node. + const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, pnode->Index()); + + const auto* p_kernel_def = kernel_create_info.kernel_def.get(); + + ORT_ENFORCE(p_kernel_def, "Should not have entry in kernel create info with nullptr for kernel_def"); auto exec_provider = execution_providers_.Get(*pnode); if (exec_provider == nullptr) { @@ -484,11 +490,9 @@ class PlannerImpl { auto* p_provider = execution_providers_.Get(node); ORT_ENFORCE(p_provider); - const KernelCreateInfo* kernel_create_info; - auto st = kernel_registry_.SearchKernelRegistry(node, &kernel_create_info); - ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); - ORT_ENFORCE(kernel_create_info != nullptr && kernel_create_info->kernel_def != nullptr); - if (kernel_create_info->kernel_def->IsInputOnCpu(input_index)) + const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, node.Index()); + + if (kernel_create_info.kernel_def->IsInputOnCpu(input_index)) // weights are not output from any node, so it's OK to put its location on CPU provider return execution_providers_.GetDefaultCpuMemoryInfo(); return p_provider->GetAllocator(0, OrtMemTypeDefault)->Info(); @@ -769,17 +773,20 @@ Status PlannerImpl::CreatePlan() { return Status::OK(); } -Status SequentialPlanner::CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer, - const std::vector& outer_scope_node_args, - const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry, - const OrtValueNameIdxMap& ort_value_name_idx_map, - const ISequentialPlannerContext& context, - std::unique_ptr& plan) { +Status SequentialPlanner::CreatePlan( + const Node* parent_node, + const onnxruntime::GraphViewer& graph_viewer, + const std::vector& outer_scope_node_args, + const ExecutionProviders& providers, + const std::unordered_map>& kernel_create_info_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, + const ISequentialPlannerContext& context, + std::unique_ptr& plan) { // allocate/reset here so we know it's clean plan = onnxruntime::make_unique(); - PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_registry, - ort_value_name_idx_map, context, *plan); + PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, + kernel_create_info_map, ort_value_name_idx_map, context, *plan); return planner.CreatePlan(); } diff --git a/onnxruntime/core/framework/allocation_planner.h b/onnxruntime/core/framework/allocation_planner.h index 35f3951577..4130bf7481 100644 --- a/onnxruntime/core/framework/allocation_planner.h +++ b/onnxruntime/core/framework/allocation_planner.h @@ -16,6 +16,7 @@ class TensorShapeProto; namespace onnxruntime { class ExecutionProviders; +struct KernelCreateInfo; class KernelRegistryManager; class OrtValueNameIdxMap; @@ -48,11 +49,14 @@ class SequentialPlannerContext : public ISequentialPlannerContext { class SequentialPlanner { public: // This API allows user to provide a custom planner context. - static Status CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph, - const std::vector& outer_scope_node_args, - const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry, - const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, - std::unique_ptr& plan); + static Status CreatePlan( + const Node* parent_node, const onnxruntime::GraphViewer& graph, + const std::vector& outer_scope_node_args, + const ExecutionProviders& providers, + const std::unordered_map>& kernel_create_info_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, + const ISequentialPlannerContext& context, + std::unique_ptr& plan); }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/finalize_session_state.h b/onnxruntime/core/framework/finalize_session_state.h deleted file mode 100644 index 07e4fb27c8..0000000000 --- a/onnxruntime/core/framework/finalize_session_state.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include - -#include "core/common/const_pointer_container.h" -#include "core/framework/allocator.h" -#include "core/framework/tensor.h" -#include "core/framework/tensor_allocator.h" -#include "core/framework/session_options.h" -#include "core/platform/path_lib.h" - -namespace onnxruntime { -class KernelRegistryManager; -class Node; -class SessionState; - -Status FinalizeSessionState(SessionState& session_state, - const std::basic_string& graph_loc, - KernelRegistryManager& kernel_registry_manager, - _In_opt_ const Node* parent_node, - const SessionOptions& session_options); - -} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 0d9fdb962e..f7136f5b6e 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -10,48 +10,18 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, - const IExecutionProvider& execution_provider, - const SessionState& session_state, - /*out*/ std::unique_ptr& op_kernel) const { - auto create_error_message = [&node](const std::string& error) { - std::ostringstream errormsg; - errormsg << error << node.OpType(); - if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")"; - return errormsg.str(); - }; +std::unique_ptr KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, + const IExecutionProvider& execution_provider, + const SessionState& session_state, + const KernelCreateInfo& kernel_create_info) const { + OpKernelInfo kernel_info(node, *kernel_create_info.kernel_def, execution_provider, + session_state.GetConstantInitializedTensors(), + session_state.GetOrtValueNameIdxMap(), + session_state.GetFuncMgr(), + session_state.GetDataTransferMgr()); - const std::string& ptype = node.GetExecutionProviderType(); - if (ptype.empty()) { - return Status(ONNXRUNTIME, FAIL, - create_error_message("The node is not placed on any Execution Provider, " - "therefore, can't find a suitable kernel for ")); - } - - Status status; - { - for (auto& registry : custom_kernel_registries_) { - status = registry->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(), - session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel); - if (status.IsOK()) { - return status; - } - } - } - - KernelRegistry* p = nullptr; - auto iter = provider_type_to_registry_.find(ptype); - if (iter != provider_type_to_registry_.end()) p = iter->second.get(); - if (p != nullptr) { - status = p->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(), - session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel); - if (status.IsOK()) { - return status; - } - } - - return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); + // OpKernel is abstract base class so can't use make_unique + return std::unique_ptr(kernel_create_info.kernel_create_func(kernel_info)); } Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& execution_providers) { @@ -88,32 +58,40 @@ bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node, /*out*/ const KernelCreateInfo** kernel_create_info) const { + Status status; + + auto create_error_message = [&node, &status](const std::string& prefix) { + std::ostringstream errormsg; + errormsg << prefix << node.OpType(); + if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; + if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). "; + if (!status.IsOK()) errormsg << status.ErrorMessage(); + + return errormsg.str(); + }; + const std::string& ptype = node.GetExecutionProviderType(); if (ptype.empty()) { - return Status(ONNXRUNTIME, FAIL, "The node is not placed on any Execution Provider"); + return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. ")); } - Status status; - { - for (auto& registry : custom_kernel_registries_) { - status = registry->TryFindKernel(node, std::string(), kernel_create_info); - if (status.IsOK()) return status; - } + + for (auto& registry : custom_kernel_registries_) { + status = registry->TryFindKernel(node, std::string(), kernel_create_info); + if (status.IsOK()) return status; } KernelRegistry* p = nullptr; auto iter = provider_type_to_registry_.find(ptype); - if (iter != provider_type_to_registry_.end()) p = iter->second.get(); + if (iter != provider_type_to_registry_.end()) { + p = iter->second.get(); + } + if (p != nullptr) { status = p->TryFindKernel(node, std::string(), kernel_create_info); if (status.IsOK()) return status; } - std::ostringstream errormsg; - errormsg << "Failed to find kernel for " << node.OpType(); - if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")."; - if (!status.IsOK()) errormsg << status.ErrorMessage(); - return Status(ONNXRUNTIME, NOT_IMPLEMENTED, errormsg.str()); + return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 45dccb8cd2..d17dfc6f76 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -42,17 +42,15 @@ class KernelRegistryManager { // Then B > A > providers void RegisterKernelRegistry(std::shared_ptr kernel_registry); - // This function assumes the node is already assigned to an execution provider - // Don't call this function before graph partition is done - Status CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, - const SessionState& session_state, - /*out*/ std::unique_ptr& op_kernel) const ORT_MUST_USE_RESULT; - // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done Status SearchKernelRegistry(const onnxruntime::Node& node, /*out*/ const KernelCreateInfo** kernel_create_info) const; + std::unique_ptr CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, + const SessionState& session_state, + const KernelCreateInfo& kernel_create_info) const ORT_MUST_USE_RESULT; + /** * Whether this node can be run on this provider */ diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 803e14f2ba..8f95fc17bc 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -7,11 +7,13 @@ #include "core/common/logging/logging.h" #include "core/common/safeint.h" +#include "core/framework/allocator.h" #include "core/framework/node_index_info.h" #include "core/framework/op_kernel.h" -#include "core/framework/utils.h" #include "core/framework/ort_value_pattern_planner.h" -#include "core/framework/allocator.h" +#include "core/framework/session_state_utils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/controlflow/utils.h" using namespace ::onnxruntime::common; @@ -116,7 +118,33 @@ void SessionState::CreateGraphInfo() { LOGS(logger_, VERBOSE) << "Done saving OrtValue mappings."; } -Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) { +Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager) { + for (auto& node : graph_.Nodes()) { + const KernelCreateInfo* kci = nullptr; + ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, &kci)); + ORT_IGNORE_RETURN_VALUE( + kernel_create_info_map_.insert({node.Index(), gsl::not_null(kci)})); + } + + for (const auto& entry : subgraph_session_states_) { + for (const auto& name_to_subgraph_session_state : entry.second) { + SessionState& subgraph_session_state = *name_to_subgraph_session_state.second; + ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager)); + } + } + + return Status::OK(); +} + +const KernelCreateInfo& SessionState::GetNodeKernelCreateInfo(NodeIndex node_index) const { + auto entry = kernel_create_info_map_.find(node_index); + // invalid node index or FinalizeSessionState should have been called. Either way it's an internal logic error + ORT_ENFORCE(entry != kernel_create_info_map_.cend()); + + return *entry->second; +} + +Status SessionState::CreateKernels(const KernelRegistryManager& kernel_registry_manager) { const GraphNodes& nodes = graph_viewer_->Nodes(); if (!nodes.empty()) { size_t max_nodeid = 0; @@ -127,22 +155,14 @@ Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_ session_kernels_.resize(max_nodeid + 1, nullptr); for (auto& node : graph_viewer_->Nodes()) { // construct and save the kernels - std::unique_ptr op_kernel; + const KernelCreateInfo& kci = GetNodeKernelCreateInfo(node.Index()); + + // the execution provider was required to be valid to find the KernelCreateInfo so we don't need to check it here onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType(); + const IExecutionProvider& exec_provider = *execution_providers_.Get(exec_provider_name); - const IExecutionProvider* exec_provider = nullptr; - if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(), - " as there's no execution provider allocated."); - } + auto op_kernel = kernel_registry_manager.CreateKernel(node, exec_provider, *this, kci); - common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel); - if (!status.IsOK()) { - return common::Status( - status.Category(), status.Code(), - MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage())); - } - assert(session_kernels_[node.Index()] == nullptr); // assumes vector is already resize()'ed to the number of nodes in the graph session_kernels_[node.Index()] = op_kernel.release(); } @@ -151,10 +171,6 @@ Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_ return Status::OK(); } -void SessionState::SetExecutionPlan(std::unique_ptr p_seq_exec_plan) { - p_seq_exec_plan_ = std::move(p_seq_exec_plan); -} - const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p_seq_exec_plan_.get(); } Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, @@ -294,7 +310,7 @@ Status ResolveDimParams(const GraphViewer& graph, Status ResolveSizeAndShape( const NodeArg* arg, const std::unordered_map& symbolic_dimensions, - size_t& size, // total number of elements. It's 0 if shape is unknown. + size_t& size, // total number of elements. It's 0 if shape is unknown. std::vector& resolved_shape) { if (!arg->Shape()) { // 0 means no shape information. @@ -576,10 +592,6 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex return const_cast(this)->GetMutableSubgraphSessionState(index, attribute_name); } -void SessionState::RemoveSubgraphSessionState(onnxruntime::NodeIndex index) { - subgraph_session_states_.erase(index); -} - const NodeIndexInfo& SessionState::GetNodeIndexInfo() const { ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo."); return *node_index_info_; @@ -618,4 +630,154 @@ const std::unordered_set* SessionState::GetToBeExecutedNodes( return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr; } +Status SessionState::CreateSubgraphSessionState() { + for (auto& node : graph_.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + const auto& ep = node.GetExecutionProviderType(); + if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) { + // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow + // node containing the subgraph it will create whatever state it needs internally. + continue; + } + + auto& attr_name = entry.first; + Graph* subgraph = entry.second; + ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved."); + + auto subgraph_session_state = + onnxruntime::make_unique(*subgraph, execution_providers_, enable_mem_pattern_, + thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, + logger_, profiler_); + + // Pass fused function manager to subgraph + subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); + + // recurse + ORT_RETURN_IF_ERROR(subgraph_session_state->CreateSubgraphSessionState()); + + // add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved + // by Compute() via OpKernelContextInternal. + AddSubgraphSessionState(node.Index(), attr_name, std::move(subgraph_session_state)); + } + } + + return Status::OK(); +} + +Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, + KernelRegistryManager& kernel_registry_manager, + const SessionOptions& session_options, + bool remove_initializers) { + // recursively create the subgraph session state instances and populate the kernel create info in them. + // it's simpler to handle the kernel create info recursively when deserializing, + // so also do it recursively when calling PopulateKernelCreateInfo for consistency. + ORT_RETURN_IF_ERROR(CreateSubgraphSessionState()); + + ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager)); + + return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, session_options, + remove_initializers); +} + +Status SessionState::FinalizeSessionStateImpl(const std::basic_string& graph_location, + KernelRegistryManager& kernel_registry_manager, + _In_opt_ const Node* parent_node, + const SessionOptions& session_options, + bool remove_initializers) { + CreateGraphInfo(); + + // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. + std::vector valid_outer_scope_node_args; + if (parent_node) { + auto outer_scope_node_args = parent_node->ImplicitInputDefs(); + valid_outer_scope_node_args.reserve(outer_scope_node_args.size()); + + std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(), + [this, &valid_outer_scope_node_args](const NodeArg* node_arg) { + int idx; + if (ort_value_name_idx_map_.GetIdx(node_arg->Name(), idx).IsOK()) { + valid_outer_scope_node_args.push_back(node_arg); + }; + }); + } + + SequentialPlannerContext context(session_options.execution_mode); + ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args, + execution_providers_, kernel_create_info_map_, + ort_value_name_idx_map_, context, p_seq_exec_plan_)); + + // Uncomment the below to dump the allocation plan to std::cout + // LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this); + + std::unique_ptr tensor_allocator_( + ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_)); + + // move initializers from TensorProto instances in Graph to OrtValue instances in SessionState + ORT_RETURN_IF_ERROR( + session_state_utils::SaveInitializedTensors( + Env::Default(), graph_location, *graph_viewer_, + execution_providers_.GetDefaultCpuMemoryInfo(), + ort_value_name_idx_map_, *tensor_allocator_, + [this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status { + return AddInitializedTensor(idx, value, &d, constant); + }, + logger_, data_transfer_mgr_)); + + // remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was + // preallocated with the some other tensors in a single 'allocate' call, which is very common. + // TODO: make it better + if (remove_initializers) { + CleanInitializedTensorsFromGraph(); + } + + ORT_RETURN_IF_ERROR(CreateKernels(kernel_registry_manager)); + + const auto disable_prepacking = + GetSessionConfigOrDefault(session_options, ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING, "0"); + + if (disable_prepacking != "1") { + ORT_RETURN_IF_ERROR(PrepackInitializedConstantTensors()); + } + + ORT_RETURN_IF_ERROR( + session_state_utils::SaveInputOutputNamesToNodeMapping(*graph_viewer_, *this, valid_outer_scope_node_args)); + + // Need to recurse into subgraph session state instances to finalize them and add the execution info + + // Currently all subgraphs need to be executed using the sequential EP due to potential deadlock with the current + // parallel executor implementation + SessionOptions subgraph_session_options(session_options); + subgraph_session_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + + for (const auto& node_to_subgraph_ss : subgraph_session_states_) { + Node& node = *graph_.GetNode(node_to_subgraph_ss.first); + + for (const auto& attr_subgraph_pair : node.GetAttributeNameToMutableSubgraphMap()) { + auto& attr_name = attr_subgraph_pair.first; + auto entry = node_to_subgraph_ss.second.find(attr_name); + // CreateSubgraphSessionState should ensure all these entries are created + ORT_ENFORCE(entry != node_to_subgraph_ss.second.cend(), + "Missing session state for subgraph. Node:'", node.Name(), + "' OpType:", node.OpType(), " Index:", node.Index(), " Attribute:", attr_name); + + SessionState& subgraph_session_state = *entry->second; + + // recurse + ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl( + graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers)); + + // setup all the info for handling the feeds and fetches used in subgraph execution + auto* p_op_kernel = GetMutableKernel(node.Index()); + ORT_ENFORCE(p_op_kernel); + + // Downcast is safe, since only control flow nodes have subgraphs + // (node.GetAttributeNameToMutableSubgraphMap() is non-empty) + auto& control_flow_kernel = static_cast(*p_op_kernel); + ORT_RETURN_IF_ERROR(control_flow_kernel.SetupSubgraphExecutionInfo(*this, attr_name, subgraph_session_state)); + } + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 59a591bd0d..d4b4c17d46 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +//// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once @@ -7,26 +7,30 @@ #include #include #include + #include "gsl/gsl" -#include "core/graph/onnx_protobuf.h" + #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/common/profiler.h" #include "core/framework/allocation_planner.h" +#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/execution_providers.h" #include "core/framework/feeds_fetches_manager.h" #include "core/framework/framework_common.h" +#include "core/framework/fuse_nodes_funcs.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/mem_pattern.h" #include "core/framework/ml_value.h" -#include "core/framework/callback.h" -#include "core/framework/ort_value_name_idx_map.h" #include "core/framework/node_index_info.h" +#include "core/framework/op_kernel.h" +#include "core/framework/ort_value_name_idx_map.h" #include "core/graph/graph_viewer.h" -#include "core/framework/fuse_nodes_funcs.h" -#include "core/platform/threadpool.h" +#include "core/graph/onnx_protobuf.h" #include "core/platform/ort_mutex.h" +#include "core/platform/path_lib.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -77,13 +81,6 @@ class SessionState { SetupAllocators(); } - // Populate OrtValueNameIdxMap and create the graph viewer. - // Call once all graph modifications like transforms are completed. - void CreateGraphInfo(); - - // Call CreateKernels after CreateGraphInfo - Status CreateKernels(const KernelRegistryManager& custom_registry_manager); - ~SessionState() { for (auto* p : session_kernels_) { delete p; @@ -141,18 +138,6 @@ class SessionState { */ const std::unordered_map& GetConstantInitializedTensors() const; - /** - Cleans the initialized tensors that have been added to SessionState as OrtValue instances from the Graph instance - where they are present as TensorProto instances and will not be used when executing the model. - */ - void CleanInitializedTensorsFromGraph(); - - /** - * Prepack the constant initialized tensors for better performance. - * The original constant initialized tensors will be removed to save memory. - */ - Status PrepackInitializedConstantTensors(); - #ifdef ENABLE_TRAINING /** Get some initialized tensors (weights). @@ -173,8 +158,7 @@ class SessionState { NameMLValMap GetInitializedTensors(const std::unordered_set& interested_weights) const; #endif - // execution plan - void SetExecutionPlan(std::unique_ptr p_seq_exec_plan); + // execution plan. nullptr until FinalizeSessionState is called const SequentialExecutionPlan* GetExecutionPlan() const; /** Get the logger for this session. @@ -235,6 +219,7 @@ class SessionState { }; using NameNodeInfoMapType = std::unordered_map>; + common::Status AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info); common::Status GetInputNodeInfo(const std::string& input_name, std::vector& node_info_vec) const; const NameNodeInfoMapType& GetInputNodeInfoMap() const; @@ -243,22 +228,12 @@ class SessionState { common::Status GetOutputNodeInfo(const std::string& output_name, std::vector& node_info_vec) const; const NameNodeInfoMapType& GetOutputNodeInfoMap() const; - /// Add a SessionState instance for executing a subgraph in a Node - /// @param index Index of Node containing subgraph - /// @param attribute_name Name of attribute containing the subgraph GraphProto - /// @param session_state SessionState for subgraph execution - void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, - std::unique_ptr session_state); + // Get the KernelCreateInfo entry for a node. SessionState must be finalized before calling. + const KernelCreateInfo& GetNodeKernelCreateInfo(NodeIndex node_index) const; /// Return SessionState for the given Node index and attribute name if found. const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const; - SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); - - // Remove the SessionState for a node containing a subgraph. - // If the node isn't going to be executed by the CPU provider we don't need it. - void RemoveSubgraphSessionState(onnxruntime::NodeIndex index); - concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; } concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; } @@ -276,11 +251,47 @@ class SessionState { void UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue_idxs); const std::unordered_set* GetToBeExecutedNodes(const std::vector& fetch_mlvalue_idxs) const; + Status FinalizeSessionState(const std::basic_string& graph_loc, + KernelRegistryManager& kernel_registry_manager, + const SessionOptions& session_options = {}, + bool remove_initializers = true); + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState); void SetupAllocators(); + // Populate OrtValueNameIdxMap and create the graph viewer. + void CreateGraphInfo(); + + // create kernels using info in kernel_create_info_map_ + Status CreateKernels(const KernelRegistryManager& custom_registry_manager); + + // remove TensorProto versions of initializers from Graph instance + // (replaced byOrtValue instances in initialized_tensors_) + void CleanInitializedTensorsFromGraph(); + + /** + * Prepack the constant initialized tensors for better performance. + * The original constant initialized tensors will be removed to save memory. + */ + Status PrepackInitializedConstantTensors(); + + SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); + + Status CreateSubgraphSessionState(); + + void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, + std::unique_ptr session_state); + + Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager); + + Status FinalizeSessionStateImpl(const std::basic_string& graph_loc, + KernelRegistryManager& kernel_registry_manager, + _In_opt_ const Node* parent_node, + const SessionOptions& session_options, + bool remove_initializers); + #ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( const std::vector>& input_shape, @@ -289,6 +300,9 @@ class SessionState { std::unordered_map& inferred_shapes) const; #endif + // KernelCreateInfo for each node so we do kernel lookup once + std::unordered_map> kernel_create_info_map_; + // cache of the constructed kernels to avoid spending construction time per executor std::vector session_kernels_; Graph& graph_; diff --git a/onnxruntime/core/framework/finalize_session_state.cc b/onnxruntime/core/framework/session_state_utils.cc similarity index 67% rename from onnxruntime/core/framework/finalize_session_state.cc rename to onnxruntime/core/framework/session_state_utils.cc index 35aa92c585..bed87aac22 100644 --- a/onnxruntime/core/framework/finalize_session_state.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/graph/onnx_protobuf.h" -#include "core/framework/finalize_session_state.h" +#include "core/framework/session_state_utils.h" #include #include @@ -26,89 +26,7 @@ #include "core/framework/tensor_allocator.h" namespace onnxruntime { - -// T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status' -template -static common::Status SaveInitializedTensors(const Env& env, const std::basic_string& graph_loc, - const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, - ITensorAllocator& planner, const T& save_tensor_func, - const logging::Logger& logger, - const DataTransferManager& data_transfer_mgr); - -static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph, - const KernelRegistryManager& custom_registry_manager, - SessionState& session_state, - const std::vector& implicit_inputs); - -Status FinalizeSessionState(SessionState& session_state, - const std::basic_string& graph_location, - KernelRegistryManager& kernel_registry_manager, - _In_opt_ const Node* parent_node, - const SessionOptions& session_options) { - session_state.CreateGraphInfo(); - - const GraphViewer& graph_viewer = session_state.GetGraphViewer(); - const auto& logger = session_state.Logger(); - - // populate the SessionState OrtValueNameIdxMap - const auto& ort_value_name_idx_map = session_state.GetOrtValueNameIdxMap(); - - // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. - std::vector valid_outer_scope_node_args; - if (parent_node) { - auto outer_scope_node_args = parent_node->ImplicitInputDefs(); - valid_outer_scope_node_args.reserve(outer_scope_node_args.size()); - - std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(), - [&ort_value_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) { - int idx; - if (ort_value_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) { - valid_outer_scope_node_args.push_back(node_arg); - }; - }); - } - - std::unique_ptr exec_plan; - SequentialPlannerContext context(session_options.execution_mode); - ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, graph_viewer, valid_outer_scope_node_args, - session_state.GetExecutionProviders(), kernel_registry_manager, - ort_value_name_idx_map, context, exec_plan)); - - const auto* exec_plan_ptr = exec_plan.get(); - session_state.SetExecutionPlan(std::move(exec_plan)); - - std::unique_ptr tensor_allocator_( - ITensorAllocator::Create(session_state.GetEnableMemoryPattern(), *exec_plan_ptr, session_state, - session_state.GetMutableWeightsBuffers())); - - // lambda to save initialized tensors into SessionState directly - const Env& env = Env::Default(); - ORT_RETURN_IF_ERROR(SaveInitializedTensors( - env, graph_location, graph_viewer, - session_state.GetExecutionProviders().GetDefaultCpuMemoryInfo(), - ort_value_name_idx_map, *tensor_allocator_, - [&session_state](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status { - return session_state.AddInitializedTensor(idx, value, &d, constant); - }, - logger, session_state.GetDataTransferMgr())); - - // remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was - // preallocated with the some other tensors in a single 'allocate' call, which is very common. - // TODO: make it better - session_state.CleanInitializedTensorsFromGraph(); - - ORT_RETURN_IF_ERROR(session_state.CreateKernels(kernel_registry_manager)); - - const auto disable_prepacking = - GetSessionConfigOrDefault(session_options, ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING, "0"); - if (disable_prepacking != "1") - ORT_RETURN_IF_ERROR(session_state.PrepackInitializedConstantTensors()); - - ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_viewer, kernel_registry_manager, session_state, - valid_outer_scope_node_args)); - return Status::OK(); -} +namespace session_state_utils { static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, @@ -170,12 +88,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st return common::Status::OK(); } -template -common::Status SaveInitializedTensors(const Env& env, const std::basic_string& graph_loc, - const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner, - const T& save_tensor_func, const logging::Logger& logger, - const DataTransferManager& data_transfer_mgr) { +common::Status SaveInitializedTensors( + const Env& env, const std::basic_string& graph_loc, + const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, + const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner, + const std::function& save_tensor_func, + const logging::Logger& logger, const DataTransferManager& data_transfer_mgr) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -248,10 +166,9 @@ static bool IsArgNameInInputsOutputs(const std::string& name, return it != graph_args.cend(); } -static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph, - const KernelRegistryManager& custom_registry_manager, - SessionState& session_state, - const std::vector& implicit_inputs) { +common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph, + SessionState& session_state, + const std::vector& implicit_inputs) { auto& graph_inputs = graph.GetInputsIncludingInitializers(); auto& graph_outputs = graph.GetOutputs(); @@ -259,9 +176,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph const auto& name_to_id = session_state.GetOrtValueNameIdxMap(); for (auto& node : graph.Nodes()) { - // note that KernelCreateInfo may not exist for custom kernel - const KernelCreateInfo* kci = nullptr; - custom_registry_manager.SearchKernelRegistry(node, &kci); + const KernelCreateInfo& kci = session_state.GetNodeKernelCreateInfo(node.Index()); ORT_RETURN_IF_ERROR( onnxruntime::Node::ForEachWithIndex( @@ -275,7 +190,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index)); const auto& device = exec_plan->GetLocation(arg_index).device; - SessionState::NodeInfo node_info(index, &node, kci, device); + SessionState::NodeInfo node_info(index, &node, &kci, device); if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) { ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info)); @@ -306,7 +221,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph int arg_index; ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index)); auto& device = exec_plan->GetLocation(arg_index).device; - SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, kci, device); + SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, &kci, device); ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info)); } } @@ -323,7 +238,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index)); const auto& device = exec_plan->GetLocation(arg_index).device; - SessionState::NodeInfo node_info(index, &node, kci, device); + SessionState::NodeInfo node_info(index, &node, &kci, device); if (IsArgNameInInputsOutputs(arg.Name(), graph_outputs)) { session_state.AddOutputNameToNodeInfoMapping(arg.Name(), node_info); @@ -361,4 +276,6 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph return Status::OK(); } + +} // namespace session_state_utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h new file mode 100644 index 0000000000..1c2110413d --- /dev/null +++ b/onnxruntime/core/framework/session_state_utils.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/common/const_pointer_container.h" +#include "core/framework/allocator.h" +#include "core/framework/tensor.h" +#include "core/framework/tensor_allocator.h" +#include "core/framework/session_options.h" +#include "core/platform/path_lib.h" + +namespace onnxruntime { +class Env; +class KernelRegistryManager; +class Node; +class SessionState; +class GraphViewer; +class OrtValueNameIdxMap; +class DataTransferManager; +class NodeArg; + +namespace logging { +class Logger; +} + +namespace session_state_utils { +common::Status SaveInitializedTensors( + const Env& env, const std::basic_string& graph_loc, + const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, + const OrtValueNameIdxMap& ort_value_name_idx_map, + ITensorAllocator& planner, + const std::function& save_tensor_func, + const logging::Logger& logger, + const DataTransferManager& data_transfer_mgr); + +common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, + SessionState& session_state, + const std::vector& implicit_inputs); +} // namespace session_state_utils +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4c6ed42341..5e8c7c74c9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -28,7 +28,6 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/mldata_type_utils.h" #include "core/framework/op_kernel_context_internal.h" -#include "core/framework/finalize_session_state.h" #include "core/framework/TensorSeq.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" @@ -701,91 +700,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, return common::Status::OK(); } -/// Create SessionState instance for each subgraph as we need that for the GraphPartitioner -/// This will be initialized by InitializeSubgraphSessions. -common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, SessionState& session_state) { - for (auto& node : graph.Nodes()) { - for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - auto& name = entry.first; - Graph* subgraph = entry.second; - ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved."); - - auto subgraph_session_state = onnxruntime::make_unique( - *subgraph, - execution_providers_, - session_state.GetEnableMemoryPattern(), - session_state.GetThreadPool(), - session_state.GetInterOpThreadPool(), - session_state.GetDataTransferMgr(), - *session_logger_, - session_profiler_); - - // Pass fused function manager to subgraph - subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr()); - - // recurse - ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state)); - - // add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved - // by Compute() via OpKernelContextInternal. - session_state.AddSubgraphSessionState(node.Index(), name, std::move(subgraph_session_state)); - } - } - - return Status::OK(); -} - -/// iterate nodes in graph looking for ones with graph attribute/s -/// @param graph The graph to iterate -/// @param session_state The SessionState instance for 'graph'. -/// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future -common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) { - for (auto& node : graph.Nodes()) { - // We only need subgraph session state for control flow nodes being handled by our CPU or CUDA execution provider. - // Remove it if it's not needed. - if (node.ContainsSubgraph()) { - const auto ep = node.GetExecutionProviderType(); - if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) { - session_state.RemoveSubgraphSessionState(node.Index()); - continue; - } - } else { - // not a control flow node - continue; - } - - for (const auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - auto& name = entry.first; - Graph& subgraph = *entry.second; - - SessionState* subgraph_session_state = session_state.GetMutableSubgraphSessionState(node.Index(), name); - ORT_ENFORCE(subgraph_session_state, "CreateSubgraphSessionState should have created an entry earlier."); - - ORT_RETURN_IF_ERROR_SESSIONID_(FinalizeSessionState(*subgraph_session_state, - model_location_, - kernel_registry_manager_, - &node, - session_options_)); - - // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), - // &*subgraph_info.session_state); - - // setup all the info for handling the feeds and fetches used in subgraph execution - auto* p_op_kernel = session_state.GetMutableKernel(node.Index()); - ORT_ENFORCE(p_op_kernel); - // Downcast is safe, since only control flow nodes have subgraphs (node.GetAttributeNameToMutableSubgraphMap() is non-empty) - auto& control_flow_kernel = static_cast(*p_op_kernel); - ORT_RETURN_IF_ERROR_SESSIONID_( - control_flow_kernel.SetupSubgraphExecutionInfo(session_state, name, *subgraph_session_state)); - - // recurse - ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(subgraph, *subgraph_session_state)); - } - } - - return Status::OK(); -} - bool InferenceSession::IsInitialized() const { std::lock_guard l(session_mutex_); return is_inited_; @@ -892,17 +806,12 @@ common::Status InferenceSession::Initialize() { if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL && execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) { - LOGS(*session_logger_, ERROR) << "Parallel execution mode doesn't support " - "CUDA Execution Provider currently."; - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Parallel execution mode doesn't support " - "CUDA Execution Provider currently."); + status = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Parallel execution mode doesn't support CUDA Execution Provider currently."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; } - // add predefined transformers - AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, - transformers_to_enable_); - onnxruntime::Graph& graph = model_->MainGraph(); // Collect the kernel registries from execution provider instances; @@ -915,8 +824,8 @@ common::Status InferenceSession::Initialize() { // Register 2nd registries into KernelRegistryManager. ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_)); - // create SessionState for subgraphs as it's needed by the transformers - ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, *session_state_)); + AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, + transformers_to_enable_); // apply any transformations to the main graph and any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_, @@ -927,6 +836,13 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + // need to keep the initializers if we're going to save the optimized model + bool keep_initializers = !session_options_.optimized_model_filepath.empty(); + + ORT_RETURN_IF_ERROR_SESSIONID_(session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, + session_options_, + !keep_initializers)); + if (!session_options_.optimized_model_filepath.empty()) { // Serialize optimized ONNX model. ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); @@ -939,14 +855,6 @@ common::Status InferenceSession::Initialize() { } } - ORT_RETURN_IF_ERROR_SESSIONID_(FinalizeSessionState(*session_state_, - model_location_, - kernel_registry_manager_, - nullptr /*parent_node*/, - session_options_)); - - // handle any subgraphs - ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, *session_state_)); session_state_->ResolveMemoryPatternFlag(); is_inited_ = true; @@ -1404,7 +1312,7 @@ std::string InferenceSession::EndProfiling() { AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { return session_state_->GetAllocator(mem_info); - } +} // assumes model has already been loaded before common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index b32fd62dc7..aa8769ae2d 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -252,7 +252,7 @@ class InferenceSession { * @return OK if success. */ common::Status Run(const NameMLValMap& feeds, const std::vector& output_names, - std::vector* p_fetches) ORT_MUST_USE_RESULT; + std::vector* p_fetches) ORT_MUST_USE_RESULT; /** * See Run(const NameMLValMap& feeds, const std::vector& output_names, std::vector* p_fetches) @@ -319,18 +319,16 @@ class InferenceSession { */ const SessionOptions& GetSessionOptions() const; - /* * Get the DataTransferManager associated with this session */ const DataTransferManager& GetDataTransferManager() const; - + /* * Get all the providers' options this session was initialized with. */ const ProviderOptionsMap& GetAllProviderOptions() const; - /** * Start profiling on this inference session. This simply turns on profiling events to be * recorded. A corresponding EndProfiling has to follow to write profiling data to a file. @@ -360,8 +358,8 @@ class InferenceSession { * @return a ptr to the allocator or nullptr if not available */ AllocatorPtr GetAllocator(const OrtMemoryInfo& mem_info) const; - - /** + + /** *Get InferenceSession logger. */ const logging::Logger* GetLogger() const { return session_logger_; }; @@ -422,20 +420,16 @@ class InferenceSession { common::Status Load(std::function&)> loader, const std::string& event_name) ORT_MUST_USE_RESULT; + virtual void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const std::vector& custom_list); + common::Status TransformGraph(onnxruntime::Graph& graph, const onnxruntime::GraphTransformerManager& graph_transformer_mgr, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, const InsertCastTransformer& insert_cast_transformer, SessionState& session_state) ORT_MUST_USE_RESULT; - common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT; - - common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT; - - virtual void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - const std::vector& custom_list); - void InitLogger(logging::LoggingManager* logging_manager); common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 8bfdb17237..96535ddc3c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -14,6 +14,8 @@ #include "core/util/thread_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "test/test_environment.h" +#include "test/util/include/asserts.h" + using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -140,7 +142,7 @@ class SequentialPlannerTestContext : public ISequentialPlannerContext { class PlannerTest : public ::testing::Test { private: void index(const std::string& name, int& out) { - ASSERT_TRUE(state_.GetOrtValueNameIdxMap().GetIdx(name, out).IsOK()); + ASSERT_TRUE(state_->GetOrtValueNameIdxMap().GetIdx(name, out).IsOK()); } onnxruntime::Model model_; @@ -160,7 +162,7 @@ class PlannerTest : public ::testing::Test { std::unique_ptr tp_; DataTransferManager dtm_; profiling::Profiler profiler_; - SessionState state_; + std::unique_ptr state_; ShapeMap shape_map_; std::unique_ptr plan_; @@ -169,15 +171,16 @@ class PlannerTest : public ::testing::Test { : model_("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 10}}, {}, DefaultLoggingManager().DefaultLogger()), graph_(model_.MainGraph()), tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams(), - concurrency::ThreadPoolType::INTRA_OP)), - state_(graph_, execution_providers_, false, tp_.get(), nullptr, dtm_, DefaultLoggingManager().DefaultLogger(), - profiler_) { + concurrency::ThreadPoolType::INTRA_OP)) { std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); in_place_kernel_ = KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build(); CPUExecutionProviderInfo epi; auto execution_provider = onnxruntime::make_unique(epi); execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider)); + + state_.reset(new SessionState(graph_, execution_providers_, false, tp_.get(), nullptr, dtm_, + DefaultLoggingManager().DefaultLogger(), profiler_)); } onnxruntime::NodeArg* Arg(const std::string& name) { @@ -203,12 +206,14 @@ class PlannerTest : public ::testing::Test { return AddNode(*in_place_kernel_, input, output); } - void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) { + void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg, + std::unordered_map>& kernel_create_info_map) { const IExecutionProvider* ep = execution_providers_.Get(*p_node); ASSERT_NE(ep, nullptr); - auto info = onnxruntime::make_unique(*p_node, kernel_def, *ep, - state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(), - state_.GetFuncMgr(), state_.GetDataTransferMgr()); + auto info = onnxruntime::make_unique( + *p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(), + state_->GetFuncMgr(), state_->GetDataTransferMgr()); + op_kernel_infos_.push_back(std::move(info)); if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider)) { auto st = reg->Register( @@ -216,6 +221,10 @@ class PlannerTest : public ::testing::Test { [](const OpKernelInfo& info) -> OpKernel* { return new DummyOpKernel(info); })); ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); } + + const KernelCreateInfo* kci; + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", &kci)); + kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } void SetShape(std::string& name, TensorShapeProto* shape) { shape_map_[Arg(name)] = shape; } @@ -229,26 +238,31 @@ class PlannerTest : public ::testing::Test { void CreatePlan(const std::vector& outer_scope_node_args = {}) { EXPECT_EQ(graph_.Resolve(), Status::OK()); - state_.CreateGraphInfo(); - std::shared_ptr reg = std::make_shared(); + std::unordered_map> kernel_create_info_map; for (auto& binding : kernel_bindings_) { - BindKernel(binding.first, binding.second, reg.get()); + BindKernel(binding.first, binding.second, reg.get(), kernel_create_info_map); } auto cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo()); KernelRegistryManager kernel_registry_manager; kernel_registry_manager.RegisterKernelRegistry(reg); - ExecutionProviders execution_providers; - execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::move(cpu_execution_provider)); - auto status = kernel_registry_manager.RegisterKernels(execution_providers); + auto status = kernel_registry_manager.RegisterKernels(execution_providers_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - status = state_.CreateKernels(kernel_registry_manager); + + // CreatePlan is called inside FinalizeSessionState and usually the initializers are removed following that. + // Leave initializers so we can duplicate the call to CreatePlan from here to validate. + const bool remove_initializers = false; + status = state_->FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, {}, + remove_initializers); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); SequentialPlannerTestContext test_context(&shape_map_); - status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers, - kernel_registry_manager, state_.GetOrtValueNameIdxMap(), test_context, plan_); + + status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_, + kernel_create_info_map, state_->GetOrtValueNameIdxMap(), test_context, + plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); @@ -278,7 +292,7 @@ class PlannerTest : public ::testing::Test { protected: Graph& GetGraph() { return graph_; } const SequentialExecutionPlan& GetPlan() const { return *plan_; } - const SessionState& GetState() const { return state_; } + const SessionState& GetState() const { return *state_; } }; TEST_F(PlannerTest, ChainTest) { diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 5e7dd7d509..5944e3ffb3 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -57,19 +57,9 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - state.CreateGraphInfo(); - - ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager)); - node->SetExecutionProviderType(xp_typ); - std::unique_ptr p_seq_exec_plan; - // TODO below line is for testing only. In production use SequentialPlanner::CreatePlan() - SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL); - ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, - kernel_registry_manager, state.GetOrtValueNameIdxMap(), context, - p_seq_exec_plan)); - state.SetExecutionPlan(std::move(p_seq_exec_plan)); + ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; ExecutionFrame frame({}, {}, {}, outputs, {}, state); @@ -146,10 +136,8 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - state.CreateGraphInfo(); - - ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager)); - + ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); + const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap(); int x_idx = -1, y_idx = -1; ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK()); @@ -206,9 +194,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - state.CreateGraphInfo(); - - ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager)); + ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); const OrtValueNameIdxMap& mlvalue_name_idx_map(state.GetOrtValueNameIdxMap()); @@ -235,14 +221,6 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::vector{2, 3}, std::vector(6, 1.0f), &v3); - std::unique_ptr p_seq_exec_plan = onnxruntime::make_unique(); - SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL); - ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, - kernel_registry_manager, mlvalue_name_idx_map, context, - p_seq_exec_plan)); - - state.SetExecutionPlan(std::move(p_seq_exec_plan)); - vector outputs; ExecutionFrame frame({x1_idx, x2_idx, x3_idx}, {v1, v2, v3}, {t3_idx}, outputs, {}, state); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index f96abdcba7..fecb8f194c 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -7,7 +7,6 @@ #include "core/framework/graph_partitioner.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" -#include "core/framework/finalize_session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" @@ -51,6 +50,10 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { auto& graph = model.MainGraph(); ExecutionProviders execution_providers; + auto tmp_cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo(false)); + auto* cpu_execution_provider = tmp_cpu_execution_provider.get(); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider))); + DataTransferManager dtm; profiling::Profiler profiler; SessionState s(graph, execution_providers, true, tp.get(), nullptr, dtm, @@ -67,7 +70,6 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); - auto cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo(false)); OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider, s.GetConstantInitializedTensors(), s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr()); @@ -76,7 +78,6 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { size_t orig_num_outputs = p_kernel->Node().OutputDefs().size(); std::cout << "node_idx: " << node.Index() << std::endl; - ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider))); KernelRegistryManager kernel_registry_manager; status = kernel_registry_manager.RegisterKernels(execution_providers); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -85,8 +86,8 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { ASSERT_STATUS_OK(kernel_registry->Register(KernelCreateInfo( std::move(kernel_def), [](const OpKernelInfo& info) -> OpKernel* { return new TestOpKernel(info); }))); kernel_registry_manager.RegisterKernelRegistry(kernel_registry); - s.CreateGraphInfo(); - ASSERT_STATUS_OK(s.CreateKernels(kernel_registry_manager)); + ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); + auto test_kernel = s.GetKernel(node.Index()); std::cout << "orig: " << orig_num_outputs << " new: " << test_kernel->Node().OutputDefs().size() << std::endl; EXPECT_EQ(orig_num_outputs, test_kernel->Node().OutputDefs().size()); @@ -140,8 +141,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()); ASSERT_TRUE(status.IsOK()) << status; - session_state.CreateGraphInfo(); - ASSERT_STATUS_OK(FinalizeSessionState(session_state, oss.str(), krm, nullptr, SessionOptions())); + ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); const auto& initialized_tensors = session_state.GetInitializedTensors(); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); @@ -259,11 +259,10 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { SessionOptions sess_options; bool use_prepacking = GetParam(); sess_options.session_configurations[ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING] = use_prepacking ? "0" : "1"; - ASSERT_STATUS_OK(FinalizeSessionState(session_state, - std::basic_string() /*graph_loc*/, - kernel_registry_manager, - nullptr /*parent_node*/, - sess_options)); + ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), + kernel_registry_manager, + sess_options)); + const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking ASSERT_EQ(const_initialized_tensors.size(), size_t(use_prepacking ? 0 : 1)); diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index 180dbc575a..dc1812ed92 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -5,7 +5,6 @@ #include "../framework/test_utils.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" -#include #include "core/framework/execution_providers.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" @@ -52,8 +51,8 @@ TEST(MemcpyTest, copy1) { SessionState s(model.MainGraph(), execution_providers, true, &tp, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - s.CreateGraphInfo(); - ASSERT_STATUS_OK(FinalizeSessionState(s, ORT_TSTR(""), kernel_registry_manager, nullptr, SessionOptions())); + SessionOptions so; + ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, so)); AllocatorPtr allocator = execution_providers.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);