mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-16 01:33:39 +00:00
Refactor session state finalization and kernel lookup usage (#4763)
* Refactor SessionState to support coming de/serialization changes - move more parts into SessionState to simplify usage - do the kernel lookup once instead of multiple times from different places - rename finalize_session_state.* to session_state_utils.* as the finalization logic is now inside SessionState * Fix some build issues * Move subgraph session state creation into SessionState. It's not needed by GraphPartitioner any more so we can delay the creation until later. Fixes issue where EP may have removed the subgraph during partitioning when taking a control flow node, and SessionState thought the subgraph was still valid. * Address PR comments * Clarify a comment
This commit is contained in:
parent
d7233c7c97
commit
dc50aa42d5
15 changed files with 459 additions and 470 deletions
|
|
@ -100,11 +100,22 @@ std::ostream& operator<<(std::ostream& out, std::pair<const SequentialExecutionP
|
|||
return out;
|
||||
}
|
||||
|
||||
static const KernelCreateInfo& GetKernelCreateInfo(
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
NodeIndex node_index) {
|
||||
auto entry = kernel_create_info_map.find(node_index);
|
||||
ORT_ENFORCE(entry != kernel_create_info_map.cend(),
|
||||
"SessionState should have saved the KernelCreateInfo prior to this running. NodeIndex:", node_index);
|
||||
|
||||
return *entry->second;
|
||||
}
|
||||
|
||||
class PlannerImpl {
|
||||
public:
|
||||
PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers,
|
||||
const KernelRegistryManager& kernel_registry, const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
|
||||
: context_(context),
|
||||
plan_(plan),
|
||||
|
|
@ -112,7 +123,7 @@ class PlannerImpl {
|
|||
graph_viewer_(graph_viewer),
|
||||
outer_scope_node_args_(outer_scope_node_args),
|
||||
execution_providers_(providers),
|
||||
kernel_registry_(kernel_registry),
|
||||
kernel_create_info_map_(kernel_create_info_map),
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map) {}
|
||||
|
||||
Status CreatePlan();
|
||||
|
|
@ -126,7 +137,7 @@ class PlannerImpl {
|
|||
const std::vector<const NodeArg*>& outer_scope_node_args_;
|
||||
const ExecutionProviders& execution_providers_;
|
||||
|
||||
const KernelRegistryManager& kernel_registry_;
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map_;
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map_;
|
||||
|
||||
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
|
||||
|
|
@ -206,13 +217,13 @@ class PlannerImpl {
|
|||
// Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
|
||||
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input) {
|
||||
auto p_output_arg = node.OutputDefs()[output_arg_num];
|
||||
const KernelCreateInfo* ci;
|
||||
Status st = kernel_registry_.SearchKernelRegistry(node, &ci);
|
||||
if (!st.IsOK() || ci == nullptr || ci->kernel_def == nullptr) {
|
||||
const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
|
||||
|
||||
if (ci.kernel_def == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::vector<std::pair<int, int>>& alias_map = ci->kernel_def->Alias();
|
||||
const std::vector<std::pair<int, int>>& alias_map = ci.kernel_def->Alias();
|
||||
auto input_args = node.InputDefs();
|
||||
for (auto pair : alias_map) {
|
||||
if (pair.second == output_arg_num) {
|
||||
|
|
@ -227,7 +238,7 @@ class PlannerImpl {
|
|||
}
|
||||
}
|
||||
|
||||
const std::vector<std::pair<int, int>>& inplace_map = ci->kernel_def->MayInplace();
|
||||
const std::vector<std::pair<int, int>>& inplace_map = ci.kernel_def->MayInplace();
|
||||
for (auto pair : inplace_map) {
|
||||
if (pair.second == output_arg_num) {
|
||||
if ((0 <= pair.first) && (static_cast<size_t>(pair.first) < input_args.size())) {
|
||||
|
|
@ -392,22 +403,17 @@ class PlannerImpl {
|
|||
|
||||
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
|
||||
auto pnode = graph_viewer_.GetNode(step.node_index);
|
||||
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
|
||||
if (pnode == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
|
||||
}
|
||||
|
||||
// Identify where each output of this node should be allocated.
|
||||
// This is determined by the opkernel bound to the node.
|
||||
const KernelCreateInfo* kernel_create_info = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry_.SearchKernelRegistry(*pnode, &kernel_create_info));
|
||||
auto p_kernel_def = kernel_create_info->kernel_def.get();
|
||||
if (nullptr == p_kernel_def) {
|
||||
std::ostringstream errormsg;
|
||||
errormsg << "No suitable kernel definition found for op " << pnode->OpType();
|
||||
if (pnode->Op() != nullptr) {
|
||||
errormsg << "(" << pnode->Op()->since_version() << ")";
|
||||
}
|
||||
if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
|
||||
return Status(ONNXRUNTIME, FAIL, errormsg.str());
|
||||
}
|
||||
// This is determined by the OpKernel bound to the node.
|
||||
const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, pnode->Index());
|
||||
|
||||
const auto* p_kernel_def = kernel_create_info.kernel_def.get();
|
||||
|
||||
ORT_ENFORCE(p_kernel_def, "Should not have entry in kernel create info with nullptr for kernel_def");
|
||||
|
||||
auto exec_provider = execution_providers_.Get(*pnode);
|
||||
if (exec_provider == nullptr) {
|
||||
|
|
@ -484,11 +490,9 @@ class PlannerImpl {
|
|||
auto* p_provider = execution_providers_.Get(node);
|
||||
ORT_ENFORCE(p_provider);
|
||||
|
||||
const KernelCreateInfo* kernel_create_info;
|
||||
auto st = kernel_registry_.SearchKernelRegistry(node, &kernel_create_info);
|
||||
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
|
||||
ORT_ENFORCE(kernel_create_info != nullptr && kernel_create_info->kernel_def != nullptr);
|
||||
if (kernel_create_info->kernel_def->IsInputOnCpu(input_index))
|
||||
const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
|
||||
|
||||
if (kernel_create_info.kernel_def->IsInputOnCpu(input_index))
|
||||
// weights are not output from any node, so it's OK to put its location on CPU provider
|
||||
return execution_providers_.GetDefaultCpuMemoryInfo();
|
||||
return p_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
|
|
@ -769,17 +773,20 @@ Status PlannerImpl::CreatePlan() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SequentialPlanner::CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
std::unique_ptr<SequentialExecutionPlan>& plan) {
|
||||
Status SequentialPlanner::CreatePlan(
|
||||
const Node* parent_node,
|
||||
const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
std::unique_ptr<SequentialExecutionPlan>& plan) {
|
||||
// allocate/reset here so we know it's clean
|
||||
plan = onnxruntime::make_unique<SequentialExecutionPlan>();
|
||||
|
||||
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_registry,
|
||||
ort_value_name_idx_map, context, *plan);
|
||||
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
|
||||
kernel_create_info_map, ort_value_name_idx_map, context, *plan);
|
||||
|
||||
return planner.CreatePlan();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class TensorShapeProto;
|
|||
namespace onnxruntime {
|
||||
|
||||
class ExecutionProviders;
|
||||
struct KernelCreateInfo;
|
||||
class KernelRegistryManager;
|
||||
class OrtValueNameIdxMap;
|
||||
|
||||
|
|
@ -48,11 +49,14 @@ class SequentialPlannerContext : public ISequentialPlannerContext {
|
|||
class SequentialPlanner {
|
||||
public:
|
||||
// This API allows user to provide a custom planner context.
|
||||
static Status CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context,
|
||||
std::unique_ptr<SequentialExecutionPlan>& plan);
|
||||
static Status CreatePlan(
|
||||
const Node* parent_node, const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
std::unique_ptr<SequentialExecutionPlan>& plan);
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,25 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <map>
|
||||
|
||||
#include "core/common/const_pointer_container.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/tensor_allocator.h"
|
||||
#include "core/framework/session_options.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class KernelRegistryManager;
|
||||
class Node;
|
||||
class SessionState;
|
||||
|
||||
Status FinalizeSessionState(SessionState& session_state,
|
||||
const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -10,48 +10,18 @@
|
|||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state,
|
||||
/*out*/ std::unique_ptr<OpKernel>& op_kernel) const {
|
||||
auto create_error_message = [&node](const std::string& error) {
|
||||
std::ostringstream errormsg;
|
||||
errormsg << error << node.OpType();
|
||||
if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")";
|
||||
if (!node.Name().empty()) errormsg << " (node " << node.Name() << ")";
|
||||
return errormsg.str();
|
||||
};
|
||||
std::unique_ptr<OpKernel> KernelRegistryManager::CreateKernel(const onnxruntime::Node& node,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state,
|
||||
const KernelCreateInfo& kernel_create_info) const {
|
||||
OpKernelInfo kernel_info(node, *kernel_create_info.kernel_def, execution_provider,
|
||||
session_state.GetConstantInitializedTensors(),
|
||||
session_state.GetOrtValueNameIdxMap(),
|
||||
session_state.GetFuncMgr(),
|
||||
session_state.GetDataTransferMgr());
|
||||
|
||||
const std::string& ptype = node.GetExecutionProviderType();
|
||||
if (ptype.empty()) {
|
||||
return Status(ONNXRUNTIME, FAIL,
|
||||
create_error_message("The node is not placed on any Execution Provider, "
|
||||
"therefore, can't find a suitable kernel for "));
|
||||
}
|
||||
|
||||
Status status;
|
||||
{
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(),
|
||||
session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
KernelRegistry* p = nullptr;
|
||||
auto iter = provider_type_to_registry_.find(ptype);
|
||||
if (iter != provider_type_to_registry_.end()) p = iter->second.get();
|
||||
if (p != nullptr) {
|
||||
status = p->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(),
|
||||
session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for "));
|
||||
// OpKernel is abstract base class so can't use make_unique
|
||||
return std::unique_ptr<OpKernel>(kernel_create_info.kernel_create_func(kernel_info));
|
||||
}
|
||||
|
||||
Status KernelRegistryManager::RegisterKernels(const ExecutionProviders& execution_providers) {
|
||||
|
|
@ -88,32 +58,40 @@ bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r,
|
|||
|
||||
Status KernelRegistryManager::SearchKernelRegistry(const onnxruntime::Node& node,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const {
|
||||
Status status;
|
||||
|
||||
auto create_error_message = [&node, &status](const std::string& prefix) {
|
||||
std::ostringstream errormsg;
|
||||
errormsg << prefix << node.OpType();
|
||||
if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")";
|
||||
if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). ";
|
||||
if (!status.IsOK()) errormsg << status.ErrorMessage();
|
||||
|
||||
return errormsg.str();
|
||||
};
|
||||
|
||||
const std::string& ptype = node.GetExecutionProviderType();
|
||||
if (ptype.empty()) {
|
||||
return Status(ONNXRUNTIME, FAIL, "The node is not placed on any Execution Provider");
|
||||
return Status(ONNXRUNTIME, FAIL, create_error_message("The node is not placed on any Execution Provider. "));
|
||||
}
|
||||
Status status;
|
||||
{
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryFindKernel(node, std::string(), kernel_create_info);
|
||||
if (status.IsOK()) return status;
|
||||
}
|
||||
|
||||
for (auto& registry : custom_kernel_registries_) {
|
||||
status = registry->TryFindKernel(node, std::string(), kernel_create_info);
|
||||
if (status.IsOK()) return status;
|
||||
}
|
||||
|
||||
KernelRegistry* p = nullptr;
|
||||
auto iter = provider_type_to_registry_.find(ptype);
|
||||
if (iter != provider_type_to_registry_.end()) p = iter->second.get();
|
||||
if (iter != provider_type_to_registry_.end()) {
|
||||
p = iter->second.get();
|
||||
}
|
||||
|
||||
if (p != nullptr) {
|
||||
status = p->TryFindKernel(node, std::string(), kernel_create_info);
|
||||
if (status.IsOK()) return status;
|
||||
}
|
||||
|
||||
std::ostringstream errormsg;
|
||||
errormsg << "Failed to find kernel for " << node.OpType();
|
||||
if (node.Op() != nullptr) errormsg << "(" << node.Op()->since_version() << ")";
|
||||
if (!node.Name().empty()) errormsg << " (node " << node.Name() << ").";
|
||||
if (!status.IsOK()) errormsg << status.ErrorMessage();
|
||||
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, errormsg.str());
|
||||
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for "));
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -42,17 +42,15 @@ class KernelRegistryManager {
|
|||
// Then B > A > providers
|
||||
void RegisterKernelRegistry(std::shared_ptr<KernelRegistry> kernel_registry);
|
||||
|
||||
// This function assumes the node is already assigned to an execution provider
|
||||
// Don't call this function before graph partition is done
|
||||
Status CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state,
|
||||
/*out*/ std::unique_ptr<OpKernel>& op_kernel) const ORT_MUST_USE_RESULT;
|
||||
|
||||
// This function assumes the node is already assigned to an execution provider
|
||||
// Don't call this function before graph partition is done
|
||||
Status SearchKernelRegistry(const onnxruntime::Node& node,
|
||||
/*out*/ const KernelCreateInfo** kernel_create_info) const;
|
||||
|
||||
std::unique_ptr<OpKernel> CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state,
|
||||
const KernelCreateInfo& kernel_create_info) const ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* Whether this node can be run on this provider
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -7,11 +7,13 @@
|
|||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/node_index_info.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/framework/ort_value_pattern_planner.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/session_state_utils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
|
||||
|
|
@ -116,7 +118,33 @@ void SessionState::CreateGraphInfo() {
|
|||
LOGS(logger_, VERBOSE) << "Done saving OrtValue mappings.";
|
||||
}
|
||||
|
||||
Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) {
|
||||
Status SessionState::PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager) {
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
ORT_RETURN_IF_ERROR(kernel_registry_manager.SearchKernelRegistry(node, &kci));
|
||||
ORT_IGNORE_RETURN_VALUE(
|
||||
kernel_create_info_map_.insert({node.Index(), gsl::not_null<const KernelCreateInfo*>(kci)}));
|
||||
}
|
||||
|
||||
for (const auto& entry : subgraph_session_states_) {
|
||||
for (const auto& name_to_subgraph_session_state : entry.second) {
|
||||
SessionState& subgraph_session_state = *name_to_subgraph_session_state.second;
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state.PopulateKernelCreateInfo(kernel_registry_manager));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const KernelCreateInfo& SessionState::GetNodeKernelCreateInfo(NodeIndex node_index) const {
|
||||
auto entry = kernel_create_info_map_.find(node_index);
|
||||
// invalid node index or FinalizeSessionState should have been called. Either way it's an internal logic error
|
||||
ORT_ENFORCE(entry != kernel_create_info_map_.cend());
|
||||
|
||||
return *entry->second;
|
||||
}
|
||||
|
||||
Status SessionState::CreateKernels(const KernelRegistryManager& kernel_registry_manager) {
|
||||
const GraphNodes& nodes = graph_viewer_->Nodes();
|
||||
if (!nodes.empty()) {
|
||||
size_t max_nodeid = 0;
|
||||
|
|
@ -127,22 +155,14 @@ Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_
|
|||
session_kernels_.resize(max_nodeid + 1, nullptr);
|
||||
for (auto& node : graph_viewer_->Nodes()) {
|
||||
// construct and save the kernels
|
||||
std::unique_ptr<OpKernel> op_kernel;
|
||||
const KernelCreateInfo& kci = GetNodeKernelCreateInfo(node.Index());
|
||||
|
||||
// the execution provider was required to be valid to find the KernelCreateInfo so we don't need to check it here
|
||||
onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType();
|
||||
const IExecutionProvider& exec_provider = *execution_providers_.Get(exec_provider_name);
|
||||
|
||||
const IExecutionProvider* exec_provider = nullptr;
|
||||
if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(),
|
||||
" as there's no execution provider allocated.");
|
||||
}
|
||||
auto op_kernel = kernel_registry_manager.CreateKernel(node, exec_provider, *this, kci);
|
||||
|
||||
common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel);
|
||||
if (!status.IsOK()) {
|
||||
return common::Status(
|
||||
status.Category(), status.Code(),
|
||||
MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage()));
|
||||
}
|
||||
assert(session_kernels_[node.Index()] == nullptr);
|
||||
// assumes vector is already resize()'ed to the number of nodes in the graph
|
||||
session_kernels_[node.Index()] = op_kernel.release();
|
||||
}
|
||||
|
|
@ -151,10 +171,6 @@ Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void SessionState::SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan) {
|
||||
p_seq_exec_plan_ = std::move(p_seq_exec_plan);
|
||||
}
|
||||
|
||||
const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p_seq_exec_plan_.get(); }
|
||||
|
||||
Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d,
|
||||
|
|
@ -294,7 +310,7 @@ Status ResolveDimParams(const GraphViewer& graph,
|
|||
Status ResolveSizeAndShape(
|
||||
const NodeArg* arg,
|
||||
const std::unordered_map<std::string, int64_t>& symbolic_dimensions,
|
||||
size_t& size, // total number of elements. It's 0 if shape is unknown.
|
||||
size_t& size, // total number of elements. It's 0 if shape is unknown.
|
||||
std::vector<int64_t>& resolved_shape) {
|
||||
if (!arg->Shape()) {
|
||||
// 0 means no shape information.
|
||||
|
|
@ -576,10 +592,6 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex
|
|||
return const_cast<SessionState*>(this)->GetMutableSubgraphSessionState(index, attribute_name);
|
||||
}
|
||||
|
||||
void SessionState::RemoveSubgraphSessionState(onnxruntime::NodeIndex index) {
|
||||
subgraph_session_states_.erase(index);
|
||||
}
|
||||
|
||||
const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
|
||||
ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo.");
|
||||
return *node_index_info_;
|
||||
|
|
@ -618,4 +630,154 @@ const std::unordered_set<NodeIndex>* SessionState::GetToBeExecutedNodes(
|
|||
return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr;
|
||||
}
|
||||
|
||||
Status SessionState::CreateSubgraphSessionState() {
|
||||
for (auto& node : graph_.Nodes()) {
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
const auto& ep = node.GetExecutionProviderType();
|
||||
if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
|
||||
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
|
||||
// node containing the subgraph it will create whatever state it needs internally.
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& attr_name = entry.first;
|
||||
Graph* subgraph = entry.second;
|
||||
ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved.");
|
||||
|
||||
auto subgraph_session_state =
|
||||
onnxruntime::make_unique<SessionState>(*subgraph, execution_providers_, enable_mem_pattern_,
|
||||
thread_pool_, inter_op_thread_pool_, data_transfer_mgr_,
|
||||
logger_, profiler_);
|
||||
|
||||
// Pass fused function manager to subgraph
|
||||
subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_);
|
||||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state->CreateSubgraphSessionState());
|
||||
|
||||
// add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved
|
||||
// by Compute() via OpKernelContextInternal.
|
||||
AddSubgraphSessionState(node.Index(), attr_name, std::move(subgraph_session_state));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers) {
|
||||
// recursively create the subgraph session state instances and populate the kernel create info in them.
|
||||
// it's simpler to handle the kernel create info recursively when deserializing,
|
||||
// so also do it recursively when calling PopulateKernelCreateInfo for consistency.
|
||||
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState());
|
||||
|
||||
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager));
|
||||
|
||||
return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, session_options,
|
||||
remove_initializers);
|
||||
}
|
||||
|
||||
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers) {
|
||||
CreateGraphInfo();
|
||||
|
||||
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
|
||||
std::vector<const NodeArg*> valid_outer_scope_node_args;
|
||||
if (parent_node) {
|
||||
auto outer_scope_node_args = parent_node->ImplicitInputDefs();
|
||||
valid_outer_scope_node_args.reserve(outer_scope_node_args.size());
|
||||
|
||||
std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(),
|
||||
[this, &valid_outer_scope_node_args](const NodeArg* node_arg) {
|
||||
int idx;
|
||||
if (ort_value_name_idx_map_.GetIdx(node_arg->Name(), idx).IsOK()) {
|
||||
valid_outer_scope_node_args.push_back(node_arg);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
SequentialPlannerContext context(session_options.execution_mode);
|
||||
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
|
||||
execution_providers_, kernel_create_info_map_,
|
||||
ort_value_name_idx_map_, context, p_seq_exec_plan_));
|
||||
|
||||
// Uncomment the below to dump the allocation plan to std::cout
|
||||
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
|
||||
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator_(
|
||||
ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_));
|
||||
|
||||
// move initializers from TensorProto instances in Graph to OrtValue instances in SessionState
|
||||
ORT_RETURN_IF_ERROR(
|
||||
session_state_utils::SaveInitializedTensors(
|
||||
Env::Default(), graph_location, *graph_viewer_,
|
||||
execution_providers_.GetDefaultCpuMemoryInfo(),
|
||||
ort_value_name_idx_map_, *tensor_allocator_,
|
||||
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
|
||||
return AddInitializedTensor(idx, value, &d, constant);
|
||||
},
|
||||
logger_, data_transfer_mgr_));
|
||||
|
||||
// remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was
|
||||
// preallocated with the some other tensors in a single 'allocate' call, which is very common.
|
||||
// TODO: make it better
|
||||
if (remove_initializers) {
|
||||
CleanInitializedTensorsFromGraph();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(CreateKernels(kernel_registry_manager));
|
||||
|
||||
const auto disable_prepacking =
|
||||
GetSessionConfigOrDefault(session_options, ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING, "0");
|
||||
|
||||
if (disable_prepacking != "1") {
|
||||
ORT_RETURN_IF_ERROR(PrepackInitializedConstantTensors());
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(
|
||||
session_state_utils::SaveInputOutputNamesToNodeMapping(*graph_viewer_, *this, valid_outer_scope_node_args));
|
||||
|
||||
// Need to recurse into subgraph session state instances to finalize them and add the execution info
|
||||
|
||||
// Currently all subgraphs need to be executed using the sequential EP due to potential deadlock with the current
|
||||
// parallel executor implementation
|
||||
SessionOptions subgraph_session_options(session_options);
|
||||
subgraph_session_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
|
||||
|
||||
for (const auto& node_to_subgraph_ss : subgraph_session_states_) {
|
||||
Node& node = *graph_.GetNode(node_to_subgraph_ss.first);
|
||||
|
||||
for (const auto& attr_subgraph_pair : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
auto& attr_name = attr_subgraph_pair.first;
|
||||
auto entry = node_to_subgraph_ss.second.find(attr_name);
|
||||
// CreateSubgraphSessionState should ensure all these entries are created
|
||||
ORT_ENFORCE(entry != node_to_subgraph_ss.second.cend(),
|
||||
"Missing session state for subgraph. Node:'", node.Name(),
|
||||
"' OpType:", node.OpType(), " Index:", node.Index(), " Attribute:", attr_name);
|
||||
|
||||
SessionState& subgraph_session_state = *entry->second;
|
||||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl(
|
||||
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers));
|
||||
|
||||
// setup all the info for handling the feeds and fetches used in subgraph execution
|
||||
auto* p_op_kernel = GetMutableKernel(node.Index());
|
||||
ORT_ENFORCE(p_op_kernel);
|
||||
|
||||
// Downcast is safe, since only control flow nodes have subgraphs
|
||||
// (node.GetAttributeNameToMutableSubgraphMap() is non-empty)
|
||||
auto& control_flow_kernel = static_cast<controlflow::IControlFlowKernel&>(*p_op_kernel);
|
||||
ORT_RETURN_IF_ERROR(control_flow_kernel.SetupSubgraphExecutionInfo(*this, attr_name, subgraph_session_state));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
//// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
|
@ -7,26 +7,30 @@
|
|||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/profiler.h"
|
||||
#include "core/framework/allocation_planner.h"
|
||||
#include "core/framework/callback.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
#include "core/framework/execution_providers.h"
|
||||
#include "core/framework/feeds_fetches_manager.h"
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/fuse_nodes_funcs.h"
|
||||
#include "core/framework/kernel_registry_manager.h"
|
||||
#include "core/framework/mem_pattern.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/framework/callback.h"
|
||||
#include "core/framework/ort_value_name_idx_map.h"
|
||||
#include "core/framework/node_index_info.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/ort_value_name_idx_map.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/framework/fuse_nodes_funcs.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -77,13 +81,6 @@ class SessionState {
|
|||
SetupAllocators();
|
||||
}
|
||||
|
||||
// Populate OrtValueNameIdxMap and create the graph viewer.
|
||||
// Call once all graph modifications like transforms are completed.
|
||||
void CreateGraphInfo();
|
||||
|
||||
// Call CreateKernels after CreateGraphInfo
|
||||
Status CreateKernels(const KernelRegistryManager& custom_registry_manager);
|
||||
|
||||
~SessionState() {
|
||||
for (auto* p : session_kernels_) {
|
||||
delete p;
|
||||
|
|
@ -141,18 +138,6 @@ class SessionState {
|
|||
*/
|
||||
const std::unordered_map<int, OrtValue>& GetConstantInitializedTensors() const;
|
||||
|
||||
/**
|
||||
Cleans the initialized tensors that have been added to SessionState as OrtValue instances from the Graph instance
|
||||
where they are present as TensorProto instances and will not be used when executing the model.
|
||||
*/
|
||||
void CleanInitializedTensorsFromGraph();
|
||||
|
||||
/**
|
||||
* Prepack the constant initialized tensors for better performance.
|
||||
* The original constant initialized tensors will be removed to save memory.
|
||||
*/
|
||||
Status PrepackInitializedConstantTensors();
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
/**
|
||||
Get some initialized tensors (weights).
|
||||
|
|
@ -173,8 +158,7 @@ class SessionState {
|
|||
NameMLValMap GetInitializedTensors(const std::unordered_set<std::string>& interested_weights) const;
|
||||
#endif
|
||||
|
||||
// execution plan
|
||||
void SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan);
|
||||
// execution plan. nullptr until FinalizeSessionState is called
|
||||
const SequentialExecutionPlan* GetExecutionPlan() const;
|
||||
/**
|
||||
Get the logger for this session.
|
||||
|
|
@ -235,6 +219,7 @@ class SessionState {
|
|||
};
|
||||
|
||||
using NameNodeInfoMapType = std::unordered_map<std::string, std::vector<NodeInfo>>;
|
||||
|
||||
common::Status AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info);
|
||||
common::Status GetInputNodeInfo(const std::string& input_name, std::vector<NodeInfo>& node_info_vec) const;
|
||||
const NameNodeInfoMapType& GetInputNodeInfoMap() const;
|
||||
|
|
@ -243,22 +228,12 @@ class SessionState {
|
|||
common::Status GetOutputNodeInfo(const std::string& output_name, std::vector<NodeInfo>& node_info_vec) const;
|
||||
const NameNodeInfoMapType& GetOutputNodeInfoMap() const;
|
||||
|
||||
/// Add a SessionState instance for executing a subgraph in a Node
|
||||
/// @param index Index of Node containing subgraph
|
||||
/// @param attribute_name Name of attribute containing the subgraph GraphProto
|
||||
/// @param session_state SessionState for subgraph execution
|
||||
void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name,
|
||||
std::unique_ptr<SessionState> session_state);
|
||||
// Get the KernelCreateInfo entry for a node. SessionState must be finalized before calling.
|
||||
const KernelCreateInfo& GetNodeKernelCreateInfo(NodeIndex node_index) const;
|
||||
|
||||
/// Return SessionState for the given Node index and attribute name if found.
|
||||
const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const;
|
||||
|
||||
SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name);
|
||||
|
||||
// Remove the SessionState for a node containing a subgraph.
|
||||
// If the node isn't going to be executed by the CPU provider we don't need it.
|
||||
void RemoveSubgraphSessionState(onnxruntime::NodeIndex index);
|
||||
|
||||
concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; }
|
||||
concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; }
|
||||
|
||||
|
|
@ -276,11 +251,47 @@ class SessionState {
|
|||
void UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue_idxs);
|
||||
const std::unordered_set<NodeIndex>* GetToBeExecutedNodes(const std::vector<int>& fetch_mlvalue_idxs) const;
|
||||
|
||||
Status FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
const SessionOptions& session_options = {},
|
||||
bool remove_initializers = true);
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
|
||||
|
||||
void SetupAllocators();
|
||||
|
||||
// Populate OrtValueNameIdxMap and create the graph viewer.
|
||||
void CreateGraphInfo();
|
||||
|
||||
// create kernels using info in kernel_create_info_map_
|
||||
Status CreateKernels(const KernelRegistryManager& custom_registry_manager);
|
||||
|
||||
// remove TensorProto versions of initializers from Graph instance
|
||||
// (replaced byOrtValue instances in initialized_tensors_)
|
||||
void CleanInitializedTensorsFromGraph();
|
||||
|
||||
/**
|
||||
* Prepack the constant initialized tensors for better performance.
|
||||
* The original constant initialized tensors will be removed to save memory.
|
||||
*/
|
||||
Status PrepackInitializedConstantTensors();
|
||||
|
||||
SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name);
|
||||
|
||||
Status CreateSubgraphSessionState();
|
||||
|
||||
void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name,
|
||||
std::unique_ptr<SessionState> session_state);
|
||||
|
||||
Status PopulateKernelCreateInfo(KernelRegistryManager& kernel_registry_manager);
|
||||
|
||||
Status FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
Status GeneratePatternGroupCache(
|
||||
const std::vector<std::reference_wrapper<const TensorShape>>& input_shape,
|
||||
|
|
@ -289,6 +300,9 @@ class SessionState {
|
|||
std::unordered_map<int, TensorShape>& inferred_shapes) const;
|
||||
#endif
|
||||
|
||||
// KernelCreateInfo for each node so we do kernel lookup once
|
||||
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>> kernel_create_info_map_;
|
||||
|
||||
// cache of the constructed kernels to avoid spending construction time per executor
|
||||
std::vector<OpKernel*> session_kernels_;
|
||||
Graph& graph_;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/framework/finalize_session_state.h"
|
||||
#include "core/framework/session_state_utils.h"
|
||||
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
|
|
@ -26,89 +26,7 @@
|
|||
#include "core/framework/tensor_allocator.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status'
|
||||
template <typename T>
|
||||
static common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
ITensorAllocator& planner, const T& save_tensor_func,
|
||||
const logging::Logger& logger,
|
||||
const DataTransferManager& data_transfer_mgr);
|
||||
|
||||
static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
SessionState& session_state,
|
||||
const std::vector<const NodeArg*>& implicit_inputs);
|
||||
|
||||
Status FinalizeSessionState(SessionState& session_state,
|
||||
const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options) {
|
||||
session_state.CreateGraphInfo();
|
||||
|
||||
const GraphViewer& graph_viewer = session_state.GetGraphViewer();
|
||||
const auto& logger = session_state.Logger();
|
||||
|
||||
// populate the SessionState OrtValueNameIdxMap
|
||||
const auto& ort_value_name_idx_map = session_state.GetOrtValueNameIdxMap();
|
||||
|
||||
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
|
||||
std::vector<const NodeArg*> valid_outer_scope_node_args;
|
||||
if (parent_node) {
|
||||
auto outer_scope_node_args = parent_node->ImplicitInputDefs();
|
||||
valid_outer_scope_node_args.reserve(outer_scope_node_args.size());
|
||||
|
||||
std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(),
|
||||
[&ort_value_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) {
|
||||
int idx;
|
||||
if (ort_value_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) {
|
||||
valid_outer_scope_node_args.push_back(node_arg);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<SequentialExecutionPlan> exec_plan;
|
||||
SequentialPlannerContext context(session_options.execution_mode);
|
||||
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, graph_viewer, valid_outer_scope_node_args,
|
||||
session_state.GetExecutionProviders(), kernel_registry_manager,
|
||||
ort_value_name_idx_map, context, exec_plan));
|
||||
|
||||
const auto* exec_plan_ptr = exec_plan.get();
|
||||
session_state.SetExecutionPlan(std::move(exec_plan));
|
||||
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator_(
|
||||
ITensorAllocator::Create(session_state.GetEnableMemoryPattern(), *exec_plan_ptr, session_state,
|
||||
session_state.GetMutableWeightsBuffers()));
|
||||
|
||||
// lambda to save initialized tensors into SessionState directly
|
||||
const Env& env = Env::Default();
|
||||
ORT_RETURN_IF_ERROR(SaveInitializedTensors(
|
||||
env, graph_location, graph_viewer,
|
||||
session_state.GetExecutionProviders().GetDefaultCpuMemoryInfo(),
|
||||
ort_value_name_idx_map, *tensor_allocator_,
|
||||
[&session_state](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
|
||||
return session_state.AddInitializedTensor(idx, value, &d, constant);
|
||||
},
|
||||
logger, session_state.GetDataTransferMgr()));
|
||||
|
||||
// remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was
|
||||
// preallocated with the some other tensors in a single 'allocate' call, which is very common.
|
||||
// TODO: make it better
|
||||
session_state.CleanInitializedTensorsFromGraph();
|
||||
|
||||
ORT_RETURN_IF_ERROR(session_state.CreateKernels(kernel_registry_manager));
|
||||
|
||||
const auto disable_prepacking =
|
||||
GetSessionConfigOrDefault(session_options, ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING, "0");
|
||||
if (disable_prepacking != "1")
|
||||
ORT_RETURN_IF_ERROR(session_state.PrepackInitializedConstantTensors());
|
||||
|
||||
ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_viewer, kernel_registry_manager, session_state,
|
||||
valid_outer_scope_node_args));
|
||||
return Status::OK();
|
||||
}
|
||||
namespace session_state_utils {
|
||||
|
||||
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
|
||||
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m,
|
||||
|
|
@ -170,12 +88,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
|
|||
return common::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner,
|
||||
const T& save_tensor_func, const logging::Logger& logger,
|
||||
const DataTransferManager& data_transfer_mgr) {
|
||||
common::Status SaveInitializedTensors(
|
||||
const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner,
|
||||
const std::function<Status(int idx, const OrtValue& value, const OrtCallback& d, bool constant)>& save_tensor_func,
|
||||
const logging::Logger& logger, const DataTransferManager& data_transfer_mgr) {
|
||||
LOGS(logger, INFO) << "Saving initialized tensors.";
|
||||
ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated.");
|
||||
|
||||
|
|
@ -248,10 +166,9 @@ static bool IsArgNameInInputsOutputs(const std::string& name,
|
|||
return it != graph_args.cend();
|
||||
}
|
||||
|
||||
static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
SessionState& session_state,
|
||||
const std::vector<const NodeArg*>& implicit_inputs) {
|
||||
common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& graph,
|
||||
SessionState& session_state,
|
||||
const std::vector<const NodeArg*>& implicit_inputs) {
|
||||
auto& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
auto& graph_outputs = graph.GetOutputs();
|
||||
|
||||
|
|
@ -259,9 +176,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph
|
|||
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
// note that KernelCreateInfo may not exist for custom kernel
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
custom_registry_manager.SearchKernelRegistry(node, &kci);
|
||||
const KernelCreateInfo& kci = session_state.GetNodeKernelCreateInfo(node.Index());
|
||||
|
||||
ORT_RETURN_IF_ERROR(
|
||||
onnxruntime::Node::ForEachWithIndex(
|
||||
|
|
@ -275,7 +190,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph
|
|||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index));
|
||||
const auto& device = exec_plan->GetLocation(arg_index).device;
|
||||
|
||||
SessionState::NodeInfo node_info(index, &node, kci, device);
|
||||
SessionState::NodeInfo node_info(index, &node, &kci, device);
|
||||
|
||||
if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) {
|
||||
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info));
|
||||
|
|
@ -306,7 +221,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph
|
|||
int arg_index;
|
||||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index));
|
||||
auto& device = exec_plan->GetLocation(arg_index).device;
|
||||
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci, device);
|
||||
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, &kci, device);
|
||||
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info));
|
||||
}
|
||||
}
|
||||
|
|
@ -323,7 +238,7 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph
|
|||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index));
|
||||
const auto& device = exec_plan->GetLocation(arg_index).device;
|
||||
|
||||
SessionState::NodeInfo node_info(index, &node, kci, device);
|
||||
SessionState::NodeInfo node_info(index, &node, &kci, device);
|
||||
|
||||
if (IsArgNameInInputsOutputs(arg.Name(), graph_outputs)) {
|
||||
session_state.AddOutputNameToNodeInfoMapping(arg.Name(), node_info);
|
||||
|
|
@ -361,4 +276,6 @@ static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace session_state_utils
|
||||
} // namespace onnxruntime
|
||||
42
onnxruntime/core/framework/session_state_utils.h
Normal file
42
onnxruntime/core/framework/session_state_utils.h
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <map>
|
||||
|
||||
#include "core/common/const_pointer_container.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/tensor_allocator.h"
|
||||
#include "core/framework/session_options.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class Env;
|
||||
class KernelRegistryManager;
|
||||
class Node;
|
||||
class SessionState;
|
||||
class GraphViewer;
|
||||
class OrtValueNameIdxMap;
|
||||
class DataTransferManager;
|
||||
class NodeArg;
|
||||
|
||||
namespace logging {
|
||||
class Logger;
|
||||
}
|
||||
|
||||
namespace session_state_utils {
|
||||
common::Status SaveInitializedTensors(
|
||||
const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
ITensorAllocator& planner,
|
||||
const std::function<Status(int idx, const OrtValue& value, const OrtCallback& d, bool constant)>& save_tensor_func,
|
||||
const logging::Logger& logger,
|
||||
const DataTransferManager& data_transfer_mgr);
|
||||
|
||||
common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph,
|
||||
SessionState& session_state,
|
||||
const std::vector<const NodeArg*>& implicit_inputs);
|
||||
} // namespace session_state_utils
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -28,7 +28,6 @@
|
|||
#include "core/framework/ort_value_pattern_planner.h"
|
||||
#include "core/framework/mldata_type_utils.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/finalize_session_state.h"
|
||||
#include "core/framework/TensorSeq.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/tensor_type_and_shape.h"
|
||||
|
|
@ -701,91 +700,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
|
|||
return common::Status::OK();
|
||||
}
|
||||
|
||||
/// Create SessionState instance for each subgraph as we need that for the GraphPartitioner
|
||||
/// This will be initialized by InitializeSubgraphSessions.
|
||||
common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, SessionState& session_state) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
auto& name = entry.first;
|
||||
Graph* subgraph = entry.second;
|
||||
ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved.");
|
||||
|
||||
auto subgraph_session_state = onnxruntime::make_unique<SessionState>(
|
||||
*subgraph,
|
||||
execution_providers_,
|
||||
session_state.GetEnableMemoryPattern(),
|
||||
session_state.GetThreadPool(),
|
||||
session_state.GetInterOpThreadPool(),
|
||||
session_state.GetDataTransferMgr(),
|
||||
*session_logger_,
|
||||
session_profiler_);
|
||||
|
||||
// Pass fused function manager to subgraph
|
||||
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
|
||||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));
|
||||
|
||||
// add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved
|
||||
// by Compute() via OpKernelContextInternal.
|
||||
session_state.AddSubgraphSessionState(node.Index(), name, std::move(subgraph_session_state));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// iterate nodes in graph looking for ones with graph attribute/s
|
||||
/// @param graph The graph to iterate
|
||||
/// @param session_state The SessionState instance for 'graph'.
|
||||
/// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future
|
||||
common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
// We only need subgraph session state for control flow nodes being handled by our CPU or CUDA execution provider.
|
||||
// Remove it if it's not needed.
|
||||
if (node.ContainsSubgraph()) {
|
||||
const auto ep = node.GetExecutionProviderType();
|
||||
if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
|
||||
session_state.RemoveSubgraphSessionState(node.Index());
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// not a control flow node
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
auto& name = entry.first;
|
||||
Graph& subgraph = *entry.second;
|
||||
|
||||
SessionState* subgraph_session_state = session_state.GetMutableSubgraphSessionState(node.Index(), name);
|
||||
ORT_ENFORCE(subgraph_session_state, "CreateSubgraphSessionState should have created an entry earlier.");
|
||||
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(FinalizeSessionState(*subgraph_session_state,
|
||||
model_location_,
|
||||
kernel_registry_manager_,
|
||||
&node,
|
||||
session_options_));
|
||||
|
||||
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
|
||||
// &*subgraph_info.session_state);
|
||||
|
||||
// setup all the info for handling the feeds and fetches used in subgraph execution
|
||||
auto* p_op_kernel = session_state.GetMutableKernel(node.Index());
|
||||
ORT_ENFORCE(p_op_kernel);
|
||||
// Downcast is safe, since only control flow nodes have subgraphs (node.GetAttributeNameToMutableSubgraphMap() is non-empty)
|
||||
auto& control_flow_kernel = static_cast<controlflow::IControlFlowKernel&>(*p_op_kernel);
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
control_flow_kernel.SetupSubgraphExecutionInfo(session_state, name, *subgraph_session_state));
|
||||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(subgraph, *subgraph_session_state));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool InferenceSession::IsInitialized() const {
|
||||
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
||||
return is_inited_;
|
||||
|
|
@ -892,17 +806,12 @@ common::Status InferenceSession::Initialize() {
|
|||
|
||||
if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL &&
|
||||
execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) {
|
||||
LOGS(*session_logger_, ERROR) << "Parallel execution mode doesn't support "
|
||||
"CUDA Execution Provider currently.";
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Parallel execution mode doesn't support "
|
||||
"CUDA Execution Provider currently.");
|
||||
status = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Parallel execution mode doesn't support CUDA Execution Provider currently.");
|
||||
LOGS(*session_logger_, ERROR) << status.ErrorMessage();
|
||||
return status;
|
||||
}
|
||||
|
||||
// add predefined transformers
|
||||
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
|
||||
transformers_to_enable_);
|
||||
|
||||
onnxruntime::Graph& graph = model_->MainGraph();
|
||||
|
||||
// Collect the kernel registries from execution provider instances;
|
||||
|
|
@ -915,8 +824,8 @@ common::Status InferenceSession::Initialize() {
|
|||
// Register 2nd registries into KernelRegistryManager.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));
|
||||
|
||||
// create SessionState for subgraphs as it's needed by the transformers
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, *session_state_));
|
||||
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
|
||||
transformers_to_enable_);
|
||||
|
||||
// apply any transformations to the main graph and any subgraphs
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
|
||||
|
|
@ -927,6 +836,13 @@ common::Status InferenceSession::Initialize() {
|
|||
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
|
||||
|
||||
// need to keep the initializers if we're going to save the optimized model
|
||||
bool keep_initializers = !session_options_.optimized_model_filepath.empty();
|
||||
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
|
||||
session_options_,
|
||||
!keep_initializers));
|
||||
|
||||
if (!session_options_.optimized_model_filepath.empty()) {
|
||||
// Serialize optimized ONNX model.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));
|
||||
|
|
@ -939,14 +855,6 @@ common::Status InferenceSession::Initialize() {
|
|||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(FinalizeSessionState(*session_state_,
|
||||
model_location_,
|
||||
kernel_registry_manager_,
|
||||
nullptr /*parent_node*/,
|
||||
session_options_));
|
||||
|
||||
// handle any subgraphs
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, *session_state_));
|
||||
session_state_->ResolveMemoryPatternFlag();
|
||||
is_inited_ = true;
|
||||
|
||||
|
|
@ -1404,7 +1312,7 @@ std::string InferenceSession::EndProfiling() {
|
|||
|
||||
AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const {
|
||||
return session_state_->GetAllocator(mem_info);
|
||||
}
|
||||
}
|
||||
|
||||
// assumes model has already been loaded before
|
||||
common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) {
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ class InferenceSession {
|
|||
* @return OK if success.
|
||||
*/
|
||||
common::Status Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* See Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches)
|
||||
|
|
@ -319,18 +319,16 @@ class InferenceSession {
|
|||
*/
|
||||
const SessionOptions& GetSessionOptions() const;
|
||||
|
||||
|
||||
/*
|
||||
* Get the DataTransferManager associated with this session
|
||||
*/
|
||||
const DataTransferManager& GetDataTransferManager() const;
|
||||
|
||||
|
||||
/*
|
||||
* Get all the providers' options this session was initialized with.
|
||||
*/
|
||||
const ProviderOptionsMap& GetAllProviderOptions() const;
|
||||
|
||||
|
||||
/**
|
||||
* Start profiling on this inference session. This simply turns on profiling events to be
|
||||
* recorded. A corresponding EndProfiling has to follow to write profiling data to a file.
|
||||
|
|
@ -360,8 +358,8 @@ class InferenceSession {
|
|||
* @return a ptr to the allocator or nullptr if not available
|
||||
*/
|
||||
AllocatorPtr GetAllocator(const OrtMemoryInfo& mem_info) const;
|
||||
|
||||
/**
|
||||
|
||||
/**
|
||||
*Get InferenceSession logger.
|
||||
*/
|
||||
const logging::Logger* GetLogger() const { return session_logger_; };
|
||||
|
|
@ -422,20 +420,16 @@ class InferenceSession {
|
|||
common::Status Load(std::function<common::Status(std::shared_ptr<Model>&)> loader,
|
||||
const std::string& event_name) ORT_MUST_USE_RESULT;
|
||||
|
||||
virtual void AddPredefinedTransformers(GraphTransformerManager& transformer_manager,
|
||||
TransformerLevel graph_optimization_level,
|
||||
const std::vector<std::string>& custom_list);
|
||||
|
||||
common::Status TransformGraph(onnxruntime::Graph& graph,
|
||||
const onnxruntime::GraphTransformerManager& graph_transformer_mgr,
|
||||
const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager,
|
||||
const InsertCastTransformer& insert_cast_transformer,
|
||||
SessionState& session_state) ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT;
|
||||
|
||||
virtual void AddPredefinedTransformers(GraphTransformerManager& transformer_manager,
|
||||
TransformerLevel graph_optimization_level,
|
||||
const std::vector<std::string>& custom_list);
|
||||
|
||||
void InitLogger(logging::LoggingManager* logging_manager);
|
||||
|
||||
common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@
|
|||
#include "core/util/thread_utils.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -140,7 +142,7 @@ class SequentialPlannerTestContext : public ISequentialPlannerContext {
|
|||
class PlannerTest : public ::testing::Test {
|
||||
private:
|
||||
void index(const std::string& name, int& out) {
|
||||
ASSERT_TRUE(state_.GetOrtValueNameIdxMap().GetIdx(name, out).IsOK());
|
||||
ASSERT_TRUE(state_->GetOrtValueNameIdxMap().GetIdx(name, out).IsOK());
|
||||
}
|
||||
|
||||
onnxruntime::Model model_;
|
||||
|
|
@ -160,7 +162,7 @@ class PlannerTest : public ::testing::Test {
|
|||
std::unique_ptr<concurrency::ThreadPool> tp_;
|
||||
DataTransferManager dtm_;
|
||||
profiling::Profiler profiler_;
|
||||
SessionState state_;
|
||||
std::unique_ptr<SessionState> state_;
|
||||
ShapeMap shape_map_;
|
||||
std::unique_ptr<SequentialExecutionPlan> plan_;
|
||||
|
||||
|
|
@ -169,15 +171,16 @@ class PlannerTest : public ::testing::Test {
|
|||
: model_("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 10}}, {}, DefaultLoggingManager().DefaultLogger()),
|
||||
graph_(model_.MainGraph()),
|
||||
tp_(concurrency::CreateThreadPool(&onnxruntime::Env::Default(), OrtThreadPoolParams(),
|
||||
concurrency::ThreadPoolType::INTRA_OP)),
|
||||
state_(graph_, execution_providers_, false, tp_.get(), nullptr, dtm_, DefaultLoggingManager().DefaultLogger(),
|
||||
profiler_) {
|
||||
concurrency::ThreadPoolType::INTRA_OP)) {
|
||||
std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
|
||||
in_place_kernel_ =
|
||||
KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build();
|
||||
CPUExecutionProviderInfo epi;
|
||||
auto execution_provider = onnxruntime::make_unique<CPUExecutionProvider>(epi);
|
||||
execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider));
|
||||
|
||||
state_.reset(new SessionState(graph_, execution_providers_, false, tp_.get(), nullptr, dtm_,
|
||||
DefaultLoggingManager().DefaultLogger(), profiler_));
|
||||
}
|
||||
|
||||
onnxruntime::NodeArg* Arg(const std::string& name) {
|
||||
|
|
@ -203,12 +206,14 @@ class PlannerTest : public ::testing::Test {
|
|||
return AddNode(*in_place_kernel_, input, output);
|
||||
}
|
||||
|
||||
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) {
|
||||
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg,
|
||||
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map) {
|
||||
const IExecutionProvider* ep = execution_providers_.Get(*p_node);
|
||||
ASSERT_NE(ep, nullptr);
|
||||
auto info = onnxruntime::make_unique<OpKernelInfo>(*p_node, kernel_def, *ep,
|
||||
state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(),
|
||||
state_.GetFuncMgr(), state_.GetDataTransferMgr());
|
||||
auto info = onnxruntime::make_unique<OpKernelInfo>(
|
||||
*p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(),
|
||||
state_->GetFuncMgr(), state_->GetDataTransferMgr());
|
||||
|
||||
op_kernel_infos_.push_back(std::move(info));
|
||||
if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider)) {
|
||||
auto st = reg->Register(
|
||||
|
|
@ -216,6 +221,10 @@ class PlannerTest : public ::testing::Test {
|
|||
[](const OpKernelInfo& info) -> OpKernel* { return new DummyOpKernel(info); }));
|
||||
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
|
||||
}
|
||||
|
||||
const KernelCreateInfo* kci;
|
||||
ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", &kci));
|
||||
kernel_create_info_map.insert({p_node->Index(), gsl::not_null<const KernelCreateInfo*>(kci)});
|
||||
}
|
||||
|
||||
void SetShape(std::string& name, TensorShapeProto* shape) { shape_map_[Arg(name)] = shape; }
|
||||
|
|
@ -229,26 +238,31 @@ class PlannerTest : public ::testing::Test {
|
|||
void CreatePlan(const std::vector<const NodeArg*>& outer_scope_node_args = {}) {
|
||||
EXPECT_EQ(graph_.Resolve(), Status::OK());
|
||||
|
||||
state_.CreateGraphInfo();
|
||||
|
||||
std::shared_ptr<KernelRegistry> reg = std::make_shared<KernelRegistry>();
|
||||
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>> kernel_create_info_map;
|
||||
|
||||
for (auto& binding : kernel_bindings_) {
|
||||
BindKernel(binding.first, binding.second, reg.get());
|
||||
BindKernel(binding.first, binding.second, reg.get(), kernel_create_info_map);
|
||||
}
|
||||
|
||||
auto cpu_execution_provider = onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
kernel_registry_manager.RegisterKernelRegistry(reg);
|
||||
ExecutionProviders execution_providers;
|
||||
execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::move(cpu_execution_provider));
|
||||
auto status = kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
auto status = kernel_registry_manager.RegisterKernels(execution_providers_);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
status = state_.CreateKernels(kernel_registry_manager);
|
||||
|
||||
// CreatePlan is called inside FinalizeSessionState and usually the initializers are removed following that.
|
||||
// Leave initializers so we can duplicate the call to CreatePlan from here to validate.
|
||||
const bool remove_initializers = false;
|
||||
status = state_->FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, {},
|
||||
remove_initializers);
|
||||
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
SequentialPlannerTestContext test_context(&shape_map_);
|
||||
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers,
|
||||
kernel_registry_manager, state_.GetOrtValueNameIdxMap(), test_context, plan_);
|
||||
|
||||
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_,
|
||||
kernel_create_info_map, state_->GetOrtValueNameIdxMap(), test_context,
|
||||
plan_);
|
||||
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size());
|
||||
|
|
@ -278,7 +292,7 @@ class PlannerTest : public ::testing::Test {
|
|||
protected:
|
||||
Graph& GetGraph() { return graph_; }
|
||||
const SequentialExecutionPlan& GetPlan() const { return *plan_; }
|
||||
const SessionState& GetState() const { return state_; }
|
||||
const SessionState& GetState() const { return *state_; }
|
||||
};
|
||||
|
||||
TEST_F(PlannerTest, ChainTest) {
|
||||
|
|
|
|||
|
|
@ -57,19 +57,9 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
|||
SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm,
|
||||
DefaultLoggingManager().DefaultLogger(), profiler);
|
||||
|
||||
state.CreateGraphInfo();
|
||||
|
||||
ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager));
|
||||
|
||||
node->SetExecutionProviderType(xp_typ);
|
||||
|
||||
std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan;
|
||||
// TODO below line is for testing only. In production use SequentialPlanner::CreatePlan()
|
||||
SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL);
|
||||
ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers,
|
||||
kernel_registry_manager, state.GetOrtValueNameIdxMap(), context,
|
||||
p_seq_exec_plan));
|
||||
state.SetExecutionPlan(std::move(p_seq_exec_plan));
|
||||
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));
|
||||
|
||||
vector<OrtValue> outputs;
|
||||
ExecutionFrame frame({}, {}, {}, outputs, {}, state);
|
||||
|
|
@ -146,10 +136,8 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
|
|||
SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm,
|
||||
DefaultLoggingManager().DefaultLogger(), profiler);
|
||||
|
||||
state.CreateGraphInfo();
|
||||
|
||||
ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager));
|
||||
|
||||
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));
|
||||
|
||||
const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap();
|
||||
int x_idx = -1, y_idx = -1;
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK());
|
||||
|
|
@ -206,9 +194,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
|
|||
SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm,
|
||||
DefaultLoggingManager().DefaultLogger(), profiler);
|
||||
|
||||
state.CreateGraphInfo();
|
||||
|
||||
ASSERT_STATUS_OK(state.CreateKernels(kernel_registry_manager));
|
||||
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));
|
||||
|
||||
const OrtValueNameIdxMap& mlvalue_name_idx_map(state.GetOrtValueNameIdxMap());
|
||||
|
||||
|
|
@ -235,14 +221,6 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
|
|||
std::vector<int64_t>{2, 3},
|
||||
std::vector<float>(6, 1.0f), &v3);
|
||||
|
||||
std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan = onnxruntime::make_unique<SequentialExecutionPlan>();
|
||||
SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL);
|
||||
ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers,
|
||||
kernel_registry_manager, mlvalue_name_idx_map, context,
|
||||
p_seq_exec_plan));
|
||||
|
||||
state.SetExecutionPlan(std::move(p_seq_exec_plan));
|
||||
|
||||
vector<OrtValue> outputs;
|
||||
ExecutionFrame frame({x1_idx, x2_idx, x3_idx}, {v1, v2, v3}, {t3_idx}, outputs, {}, state);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
#include "core/framework/graph_partitioner.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/finalize_session_state.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
|
|
@ -51,6 +50,10 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
|
|||
auto& graph = model.MainGraph();
|
||||
|
||||
ExecutionProviders execution_providers;
|
||||
auto tmp_cpu_execution_provider = onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
|
||||
auto* cpu_execution_provider = tmp_cpu_execution_provider.get();
|
||||
ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider)));
|
||||
|
||||
DataTransferManager dtm;
|
||||
profiling::Profiler profiler;
|
||||
SessionState s(graph, execution_providers, true, tp.get(), nullptr, dtm,
|
||||
|
|
@ -67,7 +70,6 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
|
|||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
|
||||
auto cpu_execution_provider = onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
|
||||
|
||||
OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider, s.GetConstantInitializedTensors(),
|
||||
s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr());
|
||||
|
|
@ -76,7 +78,6 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
|
|||
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
|
||||
std::cout << "node_idx: " << node.Index() << std::endl;
|
||||
|
||||
ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider)));
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
status = kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
|
@ -85,8 +86,8 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
|
|||
ASSERT_STATUS_OK(kernel_registry->Register(KernelCreateInfo(
|
||||
std::move(kernel_def), [](const OpKernelInfo& info) -> OpKernel* { return new TestOpKernel(info); })));
|
||||
kernel_registry_manager.RegisterKernelRegistry(kernel_registry);
|
||||
s.CreateGraphInfo();
|
||||
ASSERT_STATUS_OK(s.CreateKernels(kernel_registry_manager));
|
||||
ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));
|
||||
|
||||
auto test_kernel = s.GetKernel(node.Index());
|
||||
std::cout << "orig: " << orig_num_outputs << " new: " << test_kernel->Node().OutputDefs().size() << std::endl;
|
||||
EXPECT_EQ(orig_num_outputs, test_kernel->Node().OutputDefs().size());
|
||||
|
|
@ -140,8 +141,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
session_state.CreateGraphInfo();
|
||||
ASSERT_STATUS_OK(FinalizeSessionState(session_state, oss.str(), krm, nullptr, SessionOptions()));
|
||||
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
|
||||
|
||||
const auto& initialized_tensors = session_state.GetInitializedTensors();
|
||||
const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors();
|
||||
|
|
@ -259,11 +259,10 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
|
|||
SessionOptions sess_options;
|
||||
bool use_prepacking = GetParam();
|
||||
sess_options.session_configurations[ORT_SESSION_OPTIONS_CONFIG_DISABLEPREPACKING] = use_prepacking ? "0" : "1";
|
||||
ASSERT_STATUS_OK(FinalizeSessionState(session_state,
|
||||
std::basic_string<PATH_CHAR_TYPE>() /*graph_loc*/,
|
||||
kernel_registry_manager,
|
||||
nullptr /*parent_node*/,
|
||||
sess_options));
|
||||
ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string<PATH_CHAR_TYPE>(),
|
||||
kernel_registry_manager,
|
||||
sess_options));
|
||||
|
||||
const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors();
|
||||
// check prepacking
|
||||
ASSERT_EQ(const_initialized_tensors.size(), size_t(use_prepacking ? 0 : 1));
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
#include "../framework/test_utils.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include <core/framework/finalize_session_state.h>
|
||||
#include "core/framework/execution_providers.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/session_state.h"
|
||||
|
|
@ -52,8 +51,8 @@ TEST(MemcpyTest, copy1) {
|
|||
SessionState s(model.MainGraph(), execution_providers, true, &tp, nullptr, dtm,
|
||||
DefaultLoggingManager().DefaultLogger(), profiler);
|
||||
|
||||
s.CreateGraphInfo();
|
||||
ASSERT_STATUS_OK(FinalizeSessionState(s, ORT_TSTR(""), kernel_registry_manager, nullptr, SessionOptions()));
|
||||
SessionOptions so;
|
||||
ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager, so));
|
||||
|
||||
AllocatorPtr allocator =
|
||||
execution_providers.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);
|
||||
|
|
|
|||
Loading…
Reference in a new issue