diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index f392725ebb..084c9a1ddc 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -100,11 +100,22 @@ std::ostream& operator<<(std::ostream& out, std::pair>& kernel_create_info_map, + NodeIndex node_index) { + auto entry = kernel_create_info_map.find(node_index); + ORT_ENFORCE(entry != kernel_create_info_map.cend(), + "SessionState should have saved the KernelCreateInfo prior to this running. NodeIndex:", node_index); + + return *entry->second; +} + class PlannerImpl { public: PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer, const std::vector& outer_scope_node_args, const ExecutionProviders& providers, - const KernelRegistryManager& kernel_registry, const OrtValueNameIdxMap& ort_value_name_idx_map, + const std::unordered_map>& kernel_create_info_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) : context_(context), plan_(plan), @@ -112,7 +123,7 @@ class PlannerImpl { graph_viewer_(graph_viewer), outer_scope_node_args_(outer_scope_node_args), execution_providers_(providers), - kernel_registry_(kernel_registry), + kernel_create_info_map_(kernel_create_info_map), ort_value_name_idx_map_(ort_value_name_idx_map) {} Status CreatePlan(); @@ -126,7 +137,7 @@ class PlannerImpl { const std::vector& outer_scope_node_args_; const ExecutionProviders& execution_providers_; - const KernelRegistryManager& kernel_registry_; + const std::unordered_map>& kernel_create_info_map_; const OrtValueNameIdxMap& ort_value_name_idx_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -206,13 +217,13 @@ class PlannerImpl { // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. 
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input) { auto p_output_arg = node.OutputDefs()[output_arg_num]; - const KernelCreateInfo* ci; - Status st = kernel_registry_.SearchKernelRegistry(node, &ci); - if (!st.IsOK() || ci == nullptr || ci->kernel_def == nullptr) { + const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index()); + + if (ci.kernel_def == nullptr) { return false; } - const std::vector>& alias_map = ci->kernel_def->Alias(); + const std::vector>& alias_map = ci.kernel_def->Alias(); auto input_args = node.InputDefs(); for (auto pair : alias_map) { if (pair.second == output_arg_num) { @@ -227,7 +238,7 @@ class PlannerImpl { } } - const std::vector>& inplace_map = ci->kernel_def->MayInplace(); + const std::vector>& inplace_map = ci.kernel_def->MayInplace(); for (auto pair : inplace_map) { if (pair.second == output_arg_num) { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { @@ -392,22 +403,17 @@ class PlannerImpl { for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) { auto pnode = graph_viewer_.GetNode(step.node_index); - if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index); + if (pnode == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index); + } // Identify where each output of this node should be allocated. - // This is determined by the opkernel bound to the node. 
- const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry_.SearchKernelRegistry(*pnode, &kernel_create_info)); - auto p_kernel_def = kernel_create_info->kernel_def.get(); - if (nullptr == p_kernel_def) { - std::ostringstream errormsg; - errormsg << "No suitable kernel definition found for op " << pnode->OpType(); - if (pnode->Op() != nullptr) { - errormsg << "(" << pnode->Op()->since_version() << ")"; - } - if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")"; - return Status(ONNXRUNTIME, FAIL, errormsg.str()); - } + // This is determined by the OpKernel bound to the node. + const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, pnode->Index()); + + const auto* p_kernel_def = kernel_create_info.kernel_def.get(); + + ORT_ENFORCE(p_kernel_def, "Should not have entry in kernel create info with nullptr for kernel_def"); auto exec_provider = execution_providers_.Get(*pnode); if (exec_provider == nullptr) { @@ -484,11 +490,9 @@ class PlannerImpl { auto* p_provider = execution_providers_.Get(node); ORT_ENFORCE(p_provider); - const KernelCreateInfo* kernel_create_info; - auto st = kernel_registry_.SearchKernelRegistry(node, &kernel_create_info); - ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); - ORT_ENFORCE(kernel_create_info != nullptr && kernel_create_info->kernel_def != nullptr); - if (kernel_create_info->kernel_def->IsInputOnCpu(input_index)) + const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, node.Index()); + + if (kernel_create_info.kernel_def->IsInputOnCpu(input_index)) // weights are not output from any node, so it's OK to put its location on CPU provider return execution_providers_.GetDefaultCpuMemoryInfo(); return p_provider->GetAllocator(0, OrtMemTypeDefault)->Info(); @@ -769,17 +773,20 @@ Status PlannerImpl::CreatePlan() { return Status::OK(); } -Status SequentialPlanner::CreatePlan(const Node* parent_node, const 
onnxruntime::GraphViewer& graph_viewer, - const std::vector& outer_scope_node_args, - const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry, - const OrtValueNameIdxMap& ort_value_name_idx_map, - const ISequentialPlannerContext& context, - std::unique_ptr& plan) { +Status SequentialPlanner::CreatePlan( + const Node* parent_node, + const onnxruntime::GraphViewer& graph_viewer, + const std::vector& outer_scope_node_args, + const ExecutionProviders& providers, + const std::unordered_map>& kernel_create_info_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, + const ISequentialPlannerContext& context, + std::unique_ptr& plan) { // allocate/reset here so we know it's clean plan = onnxruntime::make_unique(); - PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_registry, - ort_value_name_idx_map, context, *plan); + PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, + kernel_create_info_map, ort_value_name_idx_map, context, *plan); return planner.CreatePlan(); } diff --git a/onnxruntime/core/framework/allocation_planner.h b/onnxruntime/core/framework/allocation_planner.h index 35f3951577..4130bf7481 100644 --- a/onnxruntime/core/framework/allocation_planner.h +++ b/onnxruntime/core/framework/allocation_planner.h @@ -16,6 +16,7 @@ class TensorShapeProto; namespace onnxruntime { class ExecutionProviders; +struct KernelCreateInfo; class KernelRegistryManager; class OrtValueNameIdxMap; @@ -48,11 +49,14 @@ class SequentialPlannerContext : public ISequentialPlannerContext { class SequentialPlanner { public: // This API allows user to provide a custom planner context. 
- static Status CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph, - const std::vector& outer_scope_node_args, - const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry, - const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, - std::unique_ptr& plan); + static Status CreatePlan( + const Node* parent_node, const onnxruntime::GraphViewer& graph, + const std::vector& outer_scope_node_args, + const ExecutionProviders& providers, + const std::unordered_map>& kernel_create_info_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, + const ISequentialPlannerContext& context, + std::unique_ptr& plan); }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/finalize_session_state.h b/onnxruntime/core/framework/finalize_session_state.h deleted file mode 100644 index 07e4fb27c8..0000000000 --- a/onnxruntime/core/framework/finalize_session_state.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include - -#include "core/common/const_pointer_container.h" -#include "core/framework/allocator.h" -#include "core/framework/tensor.h" -#include "core/framework/tensor_allocator.h" -#include "core/framework/session_options.h" -#include "core/platform/path_lib.h" - -namespace onnxruntime { -class KernelRegistryManager; -class Node; -class SessionState; - -Status FinalizeSessionState(SessionState& session_state, - const std::basic_string& graph_loc, - KernelRegistryManager& kernel_registry_manager, - _In_opt_ const Node* parent_node, - const SessionOptions& session_options); - -} // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 0d9fdb962e..f7136f5b6e 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -10,48 +10,18 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, - const IExecutionProvider& execution_provider, - const SessionState& session_state, - /*out*/ std::unique_ptr& op_kernel) const { - auto create_error_message = [&node](const std::string& error) { - std::ostringstream errormsg; - errormsg << error << node.OpType(); - if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")"; - return errormsg.str(); - }; +std::unique_ptr KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, + const IExecutionProvider& execution_provider, + const SessionState& session_state, + const KernelCreateInfo& kernel_create_info) const { + OpKernelInfo kernel_info(node, *kernel_create_info.kernel_def, execution_provider, + session_state.GetConstantInitializedTensors(), + session_state.GetOrtValueNameIdxMap(), + session_state.GetFuncMgr(), + 
session_state.GetDataTransferMgr()); - const std::string& ptype = node.GetExecutionProviderType(); - if (ptype.empty()) { - return Status(ONNXRUNTIME, FAIL, - create_error_message("The node is not placed on any Execution Provider, " - "therefore, can't find a suitable kernel for ")); - } - - Status status; - { - for (auto& registry : custom_kernel_registries_) { - status = registry->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(), - session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel); - if (status.IsOK()) { - return status; - } - } - } - - KernelRegistry* p = nullptr; - auto iter = provider_type_to_registry_.find(ptype); - if (iter != provider_type_to_registry_.end()) p = iter->second.get(); - if (p != nullptr) { - status = p->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(), - session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel); - if (status.IsOK()) { - return status; - } - } - - return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); + // OpKernel is abstract base class so can't use make_unique + return std::unique_ptr(kernel_create_info.kernel_create_func(kernel_info)); } Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& execution_providers) { @@ -88,32 +58,40 @@ bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node, /*out*/ const KernelCreateInfo** kernel_create_info) const { + Status status; + + auto create_error_message = [&node, &status](const std::string& prefix) { + std::ostringstream errormsg; + errormsg << prefix << node.OpType(); + if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; + if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). 
"; + if (!status.IsOK()) errormsg << status.ErrorMessage(); + + return errormsg.str(); + }; + const std::string& ptype = node.GetExecutionProviderType(); if (ptype.empty()) { - return Status(ONNXRUNTIME, FAIL, "The node is not placed on any Execution Provider"); + return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. ")); } - Status status; - { - for (auto& registry : custom_kernel_registries_) { - status = registry->TryFindKernel(node, std::string(), kernel_create_info); - if (status.IsOK()) return status; - } + + for (auto& registry : custom_kernel_registries_) { + status = registry->TryFindKernel(node, std::string(), kernel_create_info); + if (status.IsOK()) return status; } KernelRegistry* p = nullptr; auto iter = provider_type_to_registry_.find(ptype); - if (iter != provider_type_to_registry_.end()) p = iter->second.get(); + if (iter != provider_type_to_registry_.end()) { + p = iter->second.get(); + } + if (p != nullptr) { status = p->TryFindKernel(node, std::string(), kernel_create_info); if (status.IsOK()) return status; } - std::ostringstream errormsg; - errormsg << "Failed to find kernel for " << node.OpType(); - if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")."; - if (!status.IsOK()) errormsg << status.ErrorMessage(); - return Status(ONNXRUNTIME, NOT_IMPLEMENTED, errormsg.str()); + return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 45dccb8cd2..d17dfc6f76 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -42,17 +42,15 @@ class KernelRegistryManager { // Then B > A > providers void RegisterKernelRegistry(std::shared_ptr 
kernel_registry); - // This function assumes the node is already assigned to an execution provider - // Don't call this function before graph partition is done - Status CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, - const SessionState& session_state, - /*out*/ std::unique_ptr& op_kernel) const ORT_MUST_USE_RESULT; - // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done Status SearchKernelRegistry(const onnxruntime::Node& node, /*out*/ const KernelCreateInfo** kernel_create_info) const; + std::unique_ptr CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, + const SessionState& session_state, + const KernelCreateInfo& kernel_create_info) const ORT_MUST_USE_RESULT; + /** * Whether this node can be run on this provider */ diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 803e14f2ba..8f95fc17bc 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -7,11 +7,13 @@ #include "core/common/logging/logging.h" #include "core/common/safeint.h" +#include "core/framework/allocator.h" #include "core/framework/node_index_info.h" #include "core/framework/op_kernel.h" -#include "core/framework/utils.h" #include "core/framework/ort_value_pattern_planner.h" -#include "core/framework/allocator.h" +#include "core/framework/session_state_utils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/controlflow/utils.h" using namespace ::onnxruntime::common; @@ -116,7 +118,33 @@ void SessionState::CreateGraphInfo() { LOGS(logger_, VERBOSE) << "Done saving OrtValue mappings."; } -Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) { +Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager) { + for (auto& node : graph_.Nodes()) { + 
const KernelCreateInfo* kci = nullptr; + ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, &kci)); + ORT_IGNORE_RETURN_VALUE( + kernel_create_info_map_.insert({node.Index(), gsl::not_null(kci)})); + } + + for (const auto& entry : subgraph_session_states_) { + for (const auto& name_to_subgraph_session_state : entry.second) { + SessionState& subgraph_session_state = *name_to_subgraph_session_state.second; + ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager)); + } + } + + return Status::OK(); +} + +const KernelCreateInfo& SessionState::GetNodeKernelCreateInfo(NodeIndex node_index) const { + auto entry = kernel_create_info_map_.find(node_index); + // invalid node index or FinalizeSessionState should have been called. Either way it's an internal logic error + ORT_ENFORCE(entry != kernel_create_info_map_.cend()); + + return *entry->second; +} + +Status SessionState::CreateKernels(const KernelRegistryManager& kernel_registry_manager) { const GraphNodes& nodes = graph_viewer_->Nodes(); if (!nodes.empty()) { size_t max_nodeid = 0; @@ -127,22 +155,14 @@ Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_ session_kernels_.resize(max_nodeid + 1, nullptr); for (auto& node : graph_viewer_->Nodes()) { // construct and save the kernels - std::unique_ptr op_kernel; + const KernelCreateInfo& kci = GetNodeKernelCreateInfo(node.Index()); + + // the execution provider was required to be valid to find the KernelCreateInfo so we don't need to check it here onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType(); + const IExecutionProvider& exec_provider = *execution_providers_.Get(exec_provider_name); - const IExecutionProvider* exec_provider = nullptr; - if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(), - " as there's no 
execution provider allocated."); - } + auto op_kernel = kernel_registry_manager.CreateKernel(node, exec_provider, *this, kci); - common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel); - if (!status.IsOK()) { - return common::Status( - status.Category(), status.Code(), - MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage())); - } - assert(session_kernels_[node.Index()] == nullptr); // assumes vector is already resize()'ed to the number of nodes in the graph session_kernels_[node.Index()] = op_kernel.release(); } @@ -151,10 +171,6 @@ Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_ return Status::OK(); } -void SessionState::SetExecutionPlan(std::unique_ptr p_seq_exec_plan) { - p_seq_exec_plan_ = std::move(p_seq_exec_plan); -} - const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p_seq_exec_plan_.get(); } Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, @@ -294,7 +310,7 @@ Status ResolveDimParams(const GraphViewer& graph, Status ResolveSizeAndShape( const NodeArg* arg, const std::unordered_map& symbolic_dimensions, - size_t& size, // total number of elements. It's 0 if shape is unknown. + size_t& size, // total number of elements. It's 0 if shape is unknown. std::vector& resolved_shape) { if (!arg->Shape()) { // 0 means no shape information. 
@@ -576,10 +592,6 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex return const_cast(this)->GetMutableSubgraphSessionState(index, attribute_name); } -void SessionState::RemoveSubgraphSessionState(onnxruntime::NodeIndex index) { - subgraph_session_states_.erase(index); -} - const NodeIndexInfo& SessionState::GetNodeIndexInfo() const { ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo."); return *node_index_info_; @@ -618,4 +630,154 @@ const std::unordered_set* SessionState::GetToBeExecutedNodes( return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr; } +Status SessionState::CreateSubgraphSessionState() { + for (auto& node : graph_.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + const auto& ep = node.GetExecutionProviderType(); + if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) { + // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow + // node containing the subgraph it will create whatever state it needs internally. + continue; + } + + auto& attr_name = entry.first; + Graph* subgraph = entry.second; + ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved."); + + auto subgraph_session_state = + onnxruntime::make_unique(*subgraph, execution_providers_, enable_mem_pattern_, + thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, + logger_, profiler_); + + // Pass fused function manager to subgraph + subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); + + // recurse + ORT_RETURN_IF_ERROR(subgraph_session_state->CreateSubgraphSessionState()); + + // add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved + // by Compute() via OpKernelContextInternal. 
+ AddSubgraphSessionState(node.Index(), attr_name, std::move(subgraph_session_state)); + } + } + + return Status::OK(); +} + +Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, + KernelRegistryManager& kernel_registry_manager, + const SessionOptions& session_options, + bool remove_initializers) { + // recursively create the subgraph session state instances and populate the kernel create info in them. + // it's simpler to handle the kernel create info recursively when deserializing, + // so also do it recursively when calling PopulateKernelCreateInfo for consistency. + ORT_RETURN_IF_ERROR(CreateSubgraphSessionState()); + + ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager)); + + return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, session_options, + remove_initializers); +} + +Status SessionState::FinalizeSessionStateImpl(const std::basic_string& graph_location, + KernelRegistryManager& kernel_registry_manager, + _In_opt_ const Node* parent_node, + const SessionOptions& session_options, + bool remove_initializers) { + CreateGraphInfo(); + + // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. 
+ std::vector valid_outer_scope_node_args; + if (parent_node) { + auto outer_scope_node_args = parent_node->ImplicitInputDefs(); + valid_outer_scope_node_args.reserve(outer_scope_node_args.size()); + + std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(), + [this, &valid_outer_scope_node_args](const NodeArg* node_arg) { + int idx; + if (ort_value_name_idx_map_.GetIdx(node_arg->Name(), idx).IsOK()) { + valid_outer_scope_node_args.push_back(node_arg); + }; + }); + } + + SequentialPlannerContext context(session_options.execution_mode); + ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args, + execution_providers_, kernel_create_info_map_, + ort_value_name_idx_map_, context, p_seq_exec_plan_)); + + // Uncomment the below to dump the allocation plan to std::cout + // LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this); + + std::unique_ptr tensor_allocator_( + ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_)); + + // move initializers from TensorProto instances in Graph to OrtValue instances in SessionState + ORT_RETURN_IF_ERROR( + session_state_utils::SaveInitializedTensors( + Env::Default(), graph_location, *graph_viewer_, + execution_providers_.GetDefaultCpuMemoryInfo(), + ort_value_name_idx_map_, *tensor_allocator_, + [this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status { + return AddInitializedTensor(idx, value, &d, constant); + }, + logger_, data_transfer_mgr_)); + + // remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was + // preallocated with some other tensors in a single 'allocate' call, which is very common. 
+ // TODO: make it better + if (remove_initializers) { + CleanInitializedTensorsFromGraph(); + } + + ORT_RETURN_IF_ERROR(CreateKernels(kernel_registry_manager)); + + const auto disable_prepacking = + GetSessionConfigOrDefault(session_options, ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING, "0"); + + if (disable_prepacking != "1") { + ORT_RETURN_IF_ERROR(PrepackInitializedConstantTensors()); + } + + ORT_RETURN_IF_ERROR( + session_state_utils::SaveInputOutputNamesToNodeMapping(*graph_viewer_, *this, valid_outer_scope_node_args)); + + // Need to recurse into subgraph session state instances to finalize them and add the execution info + + // Currently all subgraphs need to be executed using the sequential EP due to potential deadlock with the current + // parallel executor implementation + SessionOptions subgraph_session_options(session_options); + subgraph_session_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + + for (const auto& node_to_subgraph_ss : subgraph_session_states_) { + Node& node = *graph_.GetNode(node_to_subgraph_ss.first); + + for (const auto& attr_subgraph_pair : node.GetAttributeNameToMutableSubgraphMap()) { + auto& attr_name = attr_subgraph_pair.first; + auto entry = node_to_subgraph_ss.second.find(attr_name); + // CreateSubgraphSessionState should ensure all these entries are created + ORT_ENFORCE(entry != node_to_subgraph_ss.second.cend(), + "Missing session state for subgraph. 
Node:'", node.Name(), + "' OpType:", node.OpType(), " Index:", node.Index(), " Attribute:", attr_name); + + SessionState& subgraph_session_state = *entry->second; + + // recurse + ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl( + graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers)); + + // setup all the info for handling the feeds and fetches used in subgraph execution + auto* p_op_kernel = GetMutableKernel(node.Index()); + ORT_ENFORCE(p_op_kernel); + + // Downcast is safe, since only control flow nodes have subgraphs + // (node.GetAttributeNameToMutableSubgraphMap() is non-empty) + auto& control_flow_kernel = static_cast(*p_op_kernel); + ORT_RETURN_IF_ERROR(control_flow_kernel.SetupSubgraphExecutionInfo(*this, attr_name, subgraph_session_state)); + } + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 59a591bd0d..d4b4c17d46 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +//// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#pragma once @@ -7,26 +7,30 @@ #include #include #include + #include "gsl/gsl" -#include "core/graph/onnx_protobuf.h" + #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/common/profiler.h" #include "core/framework/allocation_planner.h" +#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/execution_providers.h" #include "core/framework/feeds_fetches_manager.h" #include "core/framework/framework_common.h" +#include "core/framework/fuse_nodes_funcs.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/mem_pattern.h" #include "core/framework/ml_value.h" -#include "core/framework/callback.h" -#include "core/framework/ort_value_name_idx_map.h" #include "core/framework/node_index_info.h" +#include "core/framework/op_kernel.h" +#include "core/framework/ort_value_name_idx_map.h" #include "core/graph/graph_viewer.h" -#include "core/framework/fuse_nodes_funcs.h" -#include "core/platform/threadpool.h" +#include "core/graph/onnx_protobuf.h" #include "core/platform/ort_mutex.h" +#include "core/platform/path_lib.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -77,13 +81,6 @@ class SessionState { SetupAllocators(); } - // Populate OrtValueNameIdxMap and create the graph viewer. - // Call once all graph modifications like transforms are completed. - void CreateGraphInfo(); - - // Call CreateKernels after CreateGraphInfo - Status CreateKernels(const KernelRegistryManager& custom_registry_manager); - ~SessionState() { for (auto* p : session_kernels_) { delete p; @@ -141,18 +138,6 @@ class SessionState { */ const std::unordered_map& GetConstantInitializedTensors() const; - /** - Cleans the initialized tensors that have been added to SessionState as OrtValue instances from the Graph instance - where they are present as TensorProto instances and will not be used when executing the model. 
- */ - void CleanInitializedTensorsFromGraph(); - - /** - * Prepack the constant initialized tensors for better performance. - * The original constant initialized tensors will be removed to save memory. - */ - Status PrepackInitializedConstantTensors(); - #ifdef ENABLE_TRAINING /** Get some initialized tensors (weights). @@ -173,8 +158,7 @@ class SessionState { NameMLValMap GetInitializedTensors(const std::unordered_set& interested_weights) const; #endif - // execution plan - void SetExecutionPlan(std::unique_ptr p_seq_exec_plan); + // execution plan. nullptr until FinalizeSessionState is called const SequentialExecutionPlan* GetExecutionPlan() const; /** Get the logger for this session. @@ -235,6 +219,7 @@ class SessionState { }; using NameNodeInfoMapType = std::unordered_map>; + common::Status AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info); common::Status GetInputNodeInfo(const std::string& input_name, std::vector& node_info_vec) const; const NameNodeInfoMapType& GetInputNodeInfoMap() const; @@ -243,22 +228,12 @@ class SessionState { common::Status GetOutputNodeInfo(const std::string& output_name, std::vector& node_info_vec) const; const NameNodeInfoMapType& GetOutputNodeInfoMap() const; - /// Add a SessionState instance for executing a subgraph in a Node - /// @param index Index of Node containing subgraph - /// @param attribute_name Name of attribute containing the subgraph GraphProto - /// @param session_state SessionState for subgraph execution - void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, - std::unique_ptr session_state); + // Get the KernelCreateInfo entry for a node. SessionState must be finalized before calling. + const KernelCreateInfo& GetNodeKernelCreateInfo(NodeIndex node_index) const; /// Return SessionState for the given Node index and attribute name if found. 
const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const; - SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); - - // Remove the SessionState for a node containing a subgraph. - // If the node isn't going to be executed by the CPU provider we don't need it. - void RemoveSubgraphSessionState(onnxruntime::NodeIndex index); - concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; } concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; } @@ -276,11 +251,47 @@ class SessionState { void UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue_idxs); const std::unordered_set* GetToBeExecutedNodes(const std::vector& fetch_mlvalue_idxs) const; + Status FinalizeSessionState(const std::basic_string& graph_loc, + KernelRegistryManager& kernel_registry_manager, + const SessionOptions& session_options = {}, + bool remove_initializers = true); + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState); void SetupAllocators(); + // Populate OrtValueNameIdxMap and create the graph viewer. + void CreateGraphInfo(); + + // create kernels using info in kernel_create_info_map_ + Status CreateKernels(const KernelRegistryManager& custom_registry_manager); + + // remove TensorProto versions of initializers from Graph instance + // (replaced by OrtValue instances in initialized_tensors_) + void CleanInitializedTensorsFromGraph(); + + /** + * Prepack the constant initialized tensors for better performance. + * The original constant initialized tensors will be removed to save memory.
+ */ + Status PrepackInitializedConstantTensors(); + + SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); + + Status CreateSubgraphSessionState(); + + void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, + std::unique_ptr session_state); + + Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager); + + Status FinalizeSessionStateImpl(const std::basic_string& graph_loc, + KernelRegistryManager& kernel_registry_manager, + _In_opt_ const Node* parent_node, + const SessionOptions& session_options, + bool remove_initializers); + #ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( const std::vector>& input_shape, @@ -289,6 +300,9 @@ class SessionState { std::unordered_map& inferred_shapes) const; #endif + // KernelCreateInfo for each node so we do kernel lookup once + std::unordered_map> kernel_create_info_map_; + // cache of the constructed kernels to avoid spending construction time per executor std::vector session_kernels_; Graph& graph_; diff --git a/onnxruntime/core/framework/finalize_session_state.cc b/onnxruntime/core/framework/session_state_utils.cc similarity index 67% rename from onnxruntime/core/framework/finalize_session_state.cc rename to onnxruntime/core/framework/session_state_utils.cc index 35aa92c585..bed87aac22 100644 --- a/onnxruntime/core/framework/finalize_session_state.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#include "core/graph/onnx_protobuf.h" -#include "core/framework/finalize_session_state.h" +#include "core/framework/session_state_utils.h" #include #include @@ -26,89 +26,7 @@ #include "core/framework/tensor_allocator.h" namespace onnxruntime { - -// T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status' -template -static common::Status SaveInitializedTensors(const Env& env, const std::basic_string& graph_loc, - const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, - ITensorAllocator& planner, const T& save_tensor_func, - const logging::Logger& logger, - const DataTransferManager& data_transfer_mgr); - -static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph, - const KernelRegistryManager& custom_registry_manager, - SessionState& session_state, - const std::vector& implicit_inputs); - -Status FinalizeSessionState(SessionState& session_state, - const std::basic_string& graph_location, - KernelRegistryManager& kernel_registry_manager, - _In_opt_ const Node* parent_node, - const SessionOptions& session_options) { - session_state.CreateGraphInfo(); - - const GraphViewer& graph_viewer = session_state.GetGraphViewer(); - const auto& logger = session_state.Logger(); - - // populate the SessionState OrtValueNameIdxMap - const auto& ort_value_name_idx_map = session_state.GetOrtValueNameIdxMap(); - - // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. 
- std::vector valid_outer_scope_node_args; - if (parent_node) { - auto outer_scope_node_args = parent_node->ImplicitInputDefs(); - valid_outer_scope_node_args.reserve(outer_scope_node_args.size()); - - std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(), - [&ort_value_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) { - int idx; - if (ort_value_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) { - valid_outer_scope_node_args.push_back(node_arg); - }; - }); - } - - std::unique_ptr exec_plan; - SequentialPlannerContext context(session_options.execution_mode); - ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, graph_viewer, valid_outer_scope_node_args, - session_state.GetExecutionProviders(), kernel_registry_manager, - ort_value_name_idx_map, context, exec_plan)); - - const auto* exec_plan_ptr = exec_plan.get(); - session_state.SetExecutionPlan(std::move(exec_plan)); - - std::unique_ptr tensor_allocator_( - ITensorAllocator::Create(session_state.GetEnableMemoryPattern(), *exec_plan_ptr, session_state, - session_state.GetMutableWeightsBuffers())); - - // lambda to save initialized tensors into SessionState directly - const Env& env = Env::Default(); - ORT_RETURN_IF_ERROR(SaveInitializedTensors( - env, graph_location, graph_viewer, - session_state.GetExecutionProviders().GetDefaultCpuMemoryInfo(), - ort_value_name_idx_map, *tensor_allocator_, - [&session_state](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status { - return session_state.AddInitializedTensor(idx, value, &d, constant); - }, - logger, session_state.GetDataTransferMgr())); - - // remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was - // preallocated with the some other tensors in a single 'allocate' call, which is very common. 
- // TODO: make it better - session_state.CleanInitializedTensorsFromGraph(); - - ORT_RETURN_IF_ERROR(session_state.CreateKernels(kernel_registry_manager)); - - const auto disable_prepacking = - GetSessionConfigOrDefault(session_options, ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING, "0"); - if (disable_prepacking != "1") - ORT_RETURN_IF_ERROR(session_state.PrepackInitializedConstantTensors()); - - ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_viewer, kernel_registry_manager, session_state, - valid_outer_scope_node_args)); - return Status::OK(); -} +namespace session_state_utils { static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, @@ -170,12 +88,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st return common::Status::OK(); } -template -common::Status SaveInitializedTensors(const Env& env, const std::basic_string& graph_loc, - const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner, - const T& save_tensor_func, const logging::Logger& logger, - const DataTransferManager& data_transfer_mgr) { +common::Status SaveInitializedTensors( + const Env& env, const std::basic_string& graph_loc, + const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, + const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner, + const std::function& save_tensor_func, + const logging::Logger& logger, const DataTransferManager& data_transfer_mgr) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -248,10 +166,9 @@ static bool IsArgNameInInputsOutputs(const std::string& name, return it != graph_args.cend(); } -static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph, 
- const KernelRegistryManager& custom_registry_manager, - SessionState& session_state, - const std::vector& implicit_inputs) { +common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph, + SessionState& session_state, + const std::vector& implicit_inputs) { auto& graph_inputs = graph.GetInputsIncludingInitializers(); auto& graph_outputs = graph.GetOutputs(); @@ -259,9 +176,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph const auto& name_to_id = session_state.GetOrtValueNameIdxMap(); for (auto& node : graph.Nodes()) { - // note that KernelCreateInfo may not exist for custom kernel - const KernelCreateInfo* kci = nullptr; - custom_registry_manager.SearchKernelRegistry(node, &kci); + const KernelCreateInfo& kci = session_state.GetNodeKernelCreateInfo(node.Index()); ORT_RETURN_IF_ERROR( onnxruntime::Node::ForEachWithIndex( @@ -275,7 +190,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index)); const auto& device = exec_plan->GetLocation(arg_index).device; - SessionState::NodeInfo node_info(index, &node, kci, device); + SessionState::NodeInfo node_info(index, &node, &kci, device); if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) { ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info)); @@ -306,7 +221,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph int arg_index; ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index)); auto& device = exec_plan->GetLocation(arg_index).device; - SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, kci, device); + SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, &kci, device); ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info)); } } @@ -323,7 +238,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const 
onnxruntime::Graph ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index)); const auto& device = exec_plan->GetLocation(arg_index).device; - SessionState::NodeInfo node_info(index, &node, kci, device); + SessionState::NodeInfo node_info(index, &node, &kci, device); if (IsArgNameInInputsOutputs(arg.Name(), graph_outputs)) { session_state.AddOutputNameToNodeInfoMapping(arg.Name(), node_info); @@ -361,4 +276,6 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph return Status::OK(); } + +} // namespace session_state_utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h new file mode 100644 index 0000000000..1c2110413d --- /dev/null +++ b/onnxruntime/core/framework/session_state_utils.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/common/const_pointer_container.h" +#include "core/framework/allocator.h" +#include "core/framework/tensor.h" +#include "core/framework/tensor_allocator.h" +#include "core/framework/session_options.h" +#include "core/platform/path_lib.h" + +namespace onnxruntime { +class Env; +class KernelRegistryManager; +class Node; +class SessionState; +class GraphViewer; +class OrtValueNameIdxMap; +class DataTransferManager; +class NodeArg; + +namespace logging { +class Logger; +} + +namespace session_state_utils { +common::Status SaveInitializedTensors( + const Env& env, const std::basic_string& graph_loc, + const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, + const OrtValueNameIdxMap& ort_value_name_idx_map, + ITensorAllocator& planner, + const std::function& save_tensor_func, + const logging::Logger& logger, + const DataTransferManager& data_transfer_mgr); + +common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, + SessionState& session_state, + const std::vector& 
implicit_inputs); +} // namespace session_state_utils +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4c6ed42341..5e8c7c74c9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -28,7 +28,6 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/mldata_type_utils.h" #include "core/framework/op_kernel_context_internal.h" -#include "core/framework/finalize_session_state.h" #include "core/framework/TensorSeq.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" @@ -701,91 +700,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, return common::Status::OK(); } -/// Create SessionState instance for each subgraph as we need that for the GraphPartitioner -/// This will be initialized by InitializeSubgraphSessions. -common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, SessionState& session_state) { - for (auto& node : graph.Nodes()) { - for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - auto& name = entry.first; - Graph* subgraph = entry.second; - ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved."); - - auto subgraph_session_state = onnxruntime::make_unique( - *subgraph, - execution_providers_, - session_state.GetEnableMemoryPattern(), - session_state.GetThreadPool(), - session_state.GetInterOpThreadPool(), - session_state.GetDataTransferMgr(), - *session_logger_, - session_profiler_); - - // Pass fused function manager to subgraph - subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr()); - - // recurse - ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state)); - - // add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved - // by 
Compute() via OpKernelContextInternal. - session_state.AddSubgraphSessionState(node.Index(), name, std::move(subgraph_session_state)); - } - } - - return Status::OK(); -} - -/// iterate nodes in graph looking for ones with graph attribute/s -/// @param graph The graph to iterate -/// @param session_state The SessionState instance for 'graph'. -/// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future -common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) { - for (auto& node : graph.Nodes()) { - // We only need subgraph session state for control flow nodes being handled by our CPU or CUDA execution provider. - // Remove it if it's not needed. - if (node.ContainsSubgraph()) { - const auto ep = node.GetExecutionProviderType(); - if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) { - session_state.RemoveSubgraphSessionState(node.Index()); - continue; - } - } else { - // not a control flow node - continue; - } - - for (const auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - auto& name = entry.first; - Graph& subgraph = *entry.second; - - SessionState* subgraph_session_state = session_state.GetMutableSubgraphSessionState(node.Index(), name); - ORT_ENFORCE(subgraph_session_state, "CreateSubgraphSessionState should have created an entry earlier."); - - ORT_RETURN_IF_ERROR_SESSIONID_(FinalizeSessionState(*subgraph_session_state, - model_location_, - kernel_registry_manager_, - &node, - session_options_)); - - // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), - // &*subgraph_info.session_state); - - // setup all the info for handling the feeds and fetches used in subgraph execution - auto* p_op_kernel = session_state.GetMutableKernel(node.Index()); - ORT_ENFORCE(p_op_kernel); - // Downcast is safe, since only control flow nodes have subgraphs (node.GetAttributeNameToMutableSubgraphMap() is non-empty) - auto& 
control_flow_kernel = static_cast(*p_op_kernel); - ORT_RETURN_IF_ERROR_SESSIONID_( - control_flow_kernel.SetupSubgraphExecutionInfo(session_state, name, *subgraph_session_state)); - - // recurse - ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(subgraph, *subgraph_session_state)); - } - } - - return Status::OK(); -} - bool InferenceSession::IsInitialized() const { std::lock_guard l(session_mutex_); return is_inited_; @@ -892,17 +806,12 @@ common::Status InferenceSession::Initialize() { if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL && execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) { - LOGS(*session_logger_, ERROR) << "Parallel execution mode doesn't support " - "CUDA Execution Provider currently."; - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Parallel execution mode doesn't support " - "CUDA Execution Provider currently."); + status = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Parallel execution mode doesn't support CUDA Execution Provider currently."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; } - // add predefined transformers - AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, - transformers_to_enable_); - onnxruntime::Graph& graph = model_->MainGraph(); // Collect the kernel registries from execution provider instances; @@ -915,8 +824,8 @@ common::Status InferenceSession::Initialize() { // Register 2nd registries into KernelRegistryManager. 
ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_)); - // create SessionState for subgraphs as it's needed by the transformers - ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, *session_state_)); + AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, + transformers_to_enable_); // apply any transformations to the main graph and any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_, @@ -927,6 +836,13 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + // need to keep the initializers if we're going to save the optimized model + bool keep_initializers = !session_options_.optimized_model_filepath.empty(); + + ORT_RETURN_IF_ERROR_SESSIONID_(session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, + session_options_, + !keep_initializers)); + if (!session_options_.optimized_model_filepath.empty()) { // Serialize optimized ONNX model. 
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); @@ -939,14 +855,6 @@ common::Status InferenceSession::Initialize() { } } - ORT_RETURN_IF_ERROR_SESSIONID_(FinalizeSessionState(*session_state_, - model_location_, - kernel_registry_manager_, - nullptr /*parent_node*/, - session_options_)); - - // handle any subgraphs - ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, *session_state_)); session_state_->ResolveMemoryPatternFlag(); is_inited_ = true; @@ -1404,7 +1312,7 @@ std::string InferenceSession::EndProfiling() { AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { return session_state_->GetAllocator(mem_info); - } +} // assumes model has already been loaded before common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index b32fd62dc7..aa8769ae2d 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -252,7 +252,7 @@ class InferenceSession { * @return OK if success. */ common::Status Run(const NameMLValMap& feeds, const std::vector& output_names, - std::vector* p_fetches) ORT_MUST_USE_RESULT; + std::vector* p_fetches) ORT_MUST_USE_RESULT; /** * See Run(const NameMLValMap& feeds, const std::vector& output_names, std::vector* p_fetches) @@ -319,18 +319,16 @@ class InferenceSession { */ const SessionOptions& GetSessionOptions() const; - /* * Get the DataTransferManager associated with this session */ const DataTransferManager& GetDataTransferManager() const; - + /* * Get all the providers' options this session was initialized with. */ const ProviderOptionsMap& GetAllProviderOptions() const; - /** * Start profiling on this inference session. This simply turns on profiling events to be * recorded. A corresponding EndProfiling has to follow to write profiling data to a file. 
@@ -360,8 +358,8 @@ class InferenceSession { * @return a ptr to the allocator or nullptr if not available */ AllocatorPtr GetAllocator(const OrtMemoryInfo& mem_info) const; - - /** + + /** *Get InferenceSession logger. */ const logging::Logger* GetLogger() const { return session_logger_; }; @@ -422,20 +420,16 @@ class InferenceSession { common::Status Load(std::function&)> loader, const std::string& event_name) ORT_MUST_USE_RESULT; + virtual void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const std::vector& custom_list); + common::Status TransformGraph(onnxruntime::Graph& graph, const onnxruntime::GraphTransformerManager& graph_transformer_mgr, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, const InsertCastTransformer& insert_cast_transformer, SessionState& session_state) ORT_MUST_USE_RESULT; - common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT; - - common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT; - - virtual void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - const std::vector& custom_list); - void InitLogger(logging::LoggingManager* logging_manager); common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 8bfdb17237..96535ddc3c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -14,6 +14,8 @@ #include "core/util/thread_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "test/test_environment.h" +#include "test/util/include/asserts.h" + using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -140,7 +142,7 @@ 
class SequentialPlannerTestContext : public ISequentialPlannerContext { class PlannerTest : public ::testing::Test { private: void index(const std::string& name, int& out) { - ASSERT_TRUE(state_.GetOrtValueNameIdxMap().GetIdx(name, out).IsOK()); + ASSERT_TRUE(state_->GetOrtValueNameIdxMap().GetIdx(name, out).IsOK()); } onnxruntime::Model model_; @@ -160,7 +162,7 @@ class PlannerTest : public ::testing::Test { std::unique_ptr tp_; DataTransferManager dtm_; profiling::Profiler profiler_; - SessionState state_; + std::unique_ptr state_; ShapeMap shape_map_; std::unique_ptr plan_; @@ -169,15 +171,16 @@ class PlannerTest : public ::testing::Test { : model_("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 10}}, {}, DefaultLoggingManager().DefaultLogger()), graph_(model_.MainGraph()), tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams(), - concurrency::ThreadPoolType::INTRA_OP)), - state_(graph_, execution_providers_, false, tp_.get(), nullptr, dtm_, DefaultLoggingManager().DefaultLogger(), - profiler_) { + concurrency::ThreadPoolType::INTRA_OP)) { std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); in_place_kernel_ = KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build(); CPUExecutionProviderInfo epi; auto execution_provider = onnxruntime::make_unique(epi); execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider)); + + state_.reset(new SessionState(graph_, execution_providers_, false, tp_.get(), nullptr, dtm_, + DefaultLoggingManager().DefaultLogger(), profiler_)); } onnxruntime::NodeArg* Arg(const std::string& name) { @@ -203,12 +206,14 @@ class PlannerTest : public ::testing::Test { return AddNode(*in_place_kernel_, input, output); } - void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) { + void 
BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg, + std::unordered_map>& kernel_create_info_map) { const IExecutionProvider* ep = execution_providers_.Get(*p_node); ASSERT_NE(ep, nullptr); - auto info = onnxruntime::make_unique(*p_node, kernel_def, *ep, - state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(), - state_.GetFuncMgr(), state_.GetDataTransferMgr()); + auto info = onnxruntime::make_unique( + *p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(), + state_->GetFuncMgr(), state_->GetDataTransferMgr()); + op_kernel_infos_.push_back(std::move(info)); if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider)) { auto st = reg->Register( @@ -216,6 +221,10 @@ class PlannerTest : public ::testing::Test { [](const OpKernelInfo& info) -> OpKernel* { return new DummyOpKernel(info); })); ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); } + + const KernelCreateInfo* kci; + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", &kci)); + kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } void SetShape(std::string& name, TensorShapeProto* shape) { shape_map_[Arg(name)] = shape; } @@ -229,26 +238,31 @@ class PlannerTest : public ::testing::Test { void CreatePlan(const std::vector& outer_scope_node_args = {}) { EXPECT_EQ(graph_.Resolve(), Status::OK()); - state_.CreateGraphInfo(); - std::shared_ptr reg = std::make_shared(); + std::unordered_map> kernel_create_info_map; for (auto& binding : kernel_bindings_) { - BindKernel(binding.first, binding.second, reg.get()); + BindKernel(binding.first, binding.second, reg.get(), kernel_create_info_map); } auto cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo()); KernelRegistryManager kernel_registry_manager; kernel_registry_manager.RegisterKernelRegistry(reg); - ExecutionProviders execution_providers; - execution_providers.Add(onnxruntime::kCpuExecutionProvider, 
std::move(cpu_execution_provider)); - auto status = kernel_registry_manager.RegisterKernels(execution_providers); + auto status = kernel_registry_manager.RegisterKernels(execution_providers_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - status = state_.CreateKernels(kernel_registry_manager); + + // CreatePlan is called inside FinalizeSessionState and usually the initializers are removed following that. + // Leave initializers so we can duplicate the call to CreatePlan from here to validate. + const bool remove_initializers = false; + status = state_->FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, {}, + remove_initializers); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); SequentialPlannerTestContext test_context(&shape_map_); - status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers, - kernel_registry_manager, state_.GetOrtValueNameIdxMap(), test_context, plan_); + + status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_, + kernel_create_info_map, state_->GetOrtValueNameIdxMap(), test_context, + plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); @@ -278,7 +292,7 @@ class PlannerTest : public ::testing::Test { protected: Graph& GetGraph() { return graph_; } const SequentialExecutionPlan& GetPlan() const { return *plan_; } - const SessionState& GetState() const { return state_; } + const SessionState& GetState() const { return *state_; } }; TEST_F(PlannerTest, ChainTest) { diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 5e7dd7d509..5944e3ffb3 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -57,19 +57,9 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { SessionState state(graph, 
execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - state.CreateGraphInfo(); - - ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager)); - node->SetExecutionProviderType(xp_typ); - std::unique_ptr p_seq_exec_plan; - // TODO below line is for testing only. In production use SequentialPlanner::CreatePlan() - SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL); - ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, - kernel_registry_manager, state.GetOrtValueNameIdxMap(), context, - p_seq_exec_plan)); - state.SetExecutionPlan(std::move(p_seq_exec_plan)); + ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; ExecutionFrame frame({}, {}, {}, outputs, {}, state); @@ -146,10 +136,8 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - state.CreateGraphInfo(); - - ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager)); - + ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); + const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap(); int x_idx = -1, y_idx = -1; ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK()); @@ -206,9 +194,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - state.CreateGraphInfo(); - - ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager)); + ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); const OrtValueNameIdxMap& mlvalue_name_idx_map(state.GetOrtValueNameIdxMap()); @@ -235,14 +221,6 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::vector{2, 3}, std::vector(6, 1.0f), &v3); - std::unique_ptr p_seq_exec_plan = 
onnxruntime::make_unique(); - SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL); - ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, - kernel_registry_manager, mlvalue_name_idx_map, context, - p_seq_exec_plan)); - - state.SetExecutionPlan(std::move(p_seq_exec_plan)); - vector outputs; ExecutionFrame frame({x1_idx, x2_idx, x3_idx}, {v1, v2, v3}, {t3_idx}, outputs, {}, state); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index f96abdcba7..fecb8f194c 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -7,7 +7,6 @@ #include "core/framework/graph_partitioner.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" -#include "core/framework/finalize_session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" @@ -51,6 +50,10 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { auto& graph = model.MainGraph(); ExecutionProviders execution_providers; + auto tmp_cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo(false)); + auto* cpu_execution_provider = tmp_cpu_execution_provider.get(); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider))); + DataTransferManager dtm; profiling::Profiler profiler; SessionState s(graph, execution_providers, true, tp.get(), nullptr, dtm, @@ -67,7 +70,6 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); - auto cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo(false)); OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider, s.GetConstantInitializedTensors(), 
s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr()); @@ -76,7 +78,6 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { size_t orig_num_outputs = p_kernel->Node().OutputDefs().size(); std::cout << "node_idx: " << node.Index() << std::endl; - ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider))); KernelRegistryManager kernel_registry_manager; status = kernel_registry_manager.RegisterKernels(execution_providers); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -85,8 +86,8 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { ASSERT_STATUS_OK(kernel_registry->Register(KernelCreateInfo( std::move(kernel_def), [](const OpKernelInfo& info) -> OpKernel* { return new TestOpKernel(info); }))); kernel_registry_manager.RegisterKernelRegistry(kernel_registry); - s.CreateGraphInfo(); - ASSERT_STATUS_OK(s.CreateKernels(kernel_registry_manager)); + ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); + auto test_kernel = s.GetKernel(node.Index()); std::cout << "orig: " << orig_num_outputs << " new: " << test_kernel->Node().OutputDefs().size() << std::endl; EXPECT_EQ(orig_num_outputs, test_kernel->Node().OutputDefs().size()); @@ -140,8 +141,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()); ASSERT_TRUE(status.IsOK()) << status; - session_state.CreateGraphInfo(); - ASSERT_STATUS_OK(FinalizeSessionState(session_state, oss.str(), krm, nullptr, SessionOptions())); + ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); const auto& initialized_tensors = session_state.GetInitializedTensors(); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); @@ -259,11 +259,10 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { SessionOptions sess_options; bool use_prepacking = GetParam(); 
sess_options.session_configurations[ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING] = use_prepacking ? "0" : "1"; - ASSERT_STATUS_OK(FinalizeSessionState(session_state, - std::basic_string() /*graph_loc*/, - kernel_registry_manager, - nullptr /*parent_node*/, - sess_options)); + ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), + kernel_registry_manager, + sess_options)); + const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking ASSERT_EQ(const_initialized_tensors.size(), size_t(use_prepacking ? 0 : 1)); diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index 180dbc575a..dc1812ed92 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -5,7 +5,6 @@ #include "../framework/test_utils.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" -#include #include "core/framework/execution_providers.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" @@ -52,8 +51,8 @@ TEST(MemcpyTest, copy1) { SessionState s(model.MainGraph(), execution_providers, true, &tp, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), profiler); - s.CreateGraphInfo(); - ASSERT_STATUS_OK(FinalizeSessionState(s, ORT_TSTR(""), kernel_registry_manager, nullptr, SessionOptions())); + SessionOptions so; + ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, so)); AllocatorPtr allocator = execution_providers.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);