diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index b9f94e5fab..b77c5d8d5c 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -25,7 +25,9 @@ class KernelRegistry { // TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent Status CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, - const SessionState& session_state, + const std::unordered_map& initialized_tensors, + const MLValueNameIdxMap& mlvalue_name_idx_map, + const FuncManager& funcs_mgr, std::unique_ptr& op_kernel) const; // Check if an execution provider can create kernel for a node and return diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h index 47bb0e9bde..6f6db5047a 100644 --- a/include/onnxruntime/core/framework/op_kernel_info.h +++ b/include/onnxruntime/core/framework/op_kernel_info.h @@ -13,9 +13,9 @@ namespace onnxruntime { -class SessionState; - -/** +class MLValueNameIdxMap; +class FuncManager; + /** A very light-weight class, which works as an aggregated view of all data needed for constructing a Kernel instance. NOTE: it does not own/hold any objects. @@ -25,7 +25,9 @@ class OpKernelInfo : public OpNodeProtoHelper { explicit OpKernelInfo(const onnxruntime::Node& node, const KernelDef& kernel_def, const IExecutionProvider& execution_provider, - const SessionState& session_state); + const std::unordered_map& initialized_tensors, + const MLValueNameIdxMap& mlvalue_name_idx_map, + const FuncManager& funcs_mgr); OpKernelInfo(const OpKernelInfo& other); @@ -52,8 +54,10 @@ class OpKernelInfo : public OpNodeProtoHelper { // For non cpu/cuda case, this pointer should be set so that function kernel // will delegate kernel compute call to compute call. gsl::not_null execution_provider_; + const std::unordered_map& initialized_tensors_; + const MLValueNameIdxMap& mlvalue_name_idx_map_; + const FuncManager& funcs_mgr_; ProtoHelperNodeContext proto_helper_context_; - const SessionState& session_state_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index 3495665c42..28c73d744a 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -3,8 +3,8 @@ #include #include - #include "core/framework/kernel_registry.h" +#include "core/framework/session_state.h" using namespace ::onnxruntime::common; namespace onnxruntime { @@ -260,7 +260,9 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) { Status KernelRegistry::CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, - const SessionState& session_state, + const std::unordered_map& initialized_tensors, + const MLValueNameIdxMap& mlvalue_name_idx_map, + const FuncManager& funcs_mgr, /*out*/ std::unique_ptr& op_kernel) const { const KernelCreateInfo* kernel_create_info = TryFindKernel(node, execution_provider.Type()); @@ -268,7 +270,12 @@ Status KernelRegistry::CreateKernel(const onnxruntime::Node& node, return Status(ONNXRUNTIME, FAIL, "Failed to find kernel for " + node.OpType()); } - OpKernelInfo kernel_info(node, *kernel_create_info->kernel_def, execution_provider, session_state); + OpKernelInfo kernel_info(node, + *kernel_create_info->kernel_def, + execution_provider, + initialized_tensors, + mlvalue_name_idx_map, + funcs_mgr); op_kernel.reset(kernel_create_info->kernel_create_func(kernel_info)); return Status::OK(); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 9ae5f335af..2f9199c806 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -5,6 +5,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/customregistry.h" #include "core/framework/execution_providers.h" +#include "core/framework/session_state.h" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; @@ -20,7 +21,12 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node, Status status; for (auto& registry : kernel_registries_) { - status = registry->CreateKernel(node, execution_provider, session_state, op_kernel); + status = registry->CreateKernel(node, + execution_provider, + session_state.GetInitializedTensors(), + session_state.GetMLValueNameIdxMap(), + session_state.GetFuncMgr(), + op_kernel); if (status.IsOK()) { return status; } diff --git a/onnxruntime/core/framework/op_kernel_info.cc b/onnxruntime/core/framework/op_kernel_info.cc index 09a4359c37..b1ad1b72de 100644 --- a/onnxruntime/core/framework/op_kernel_info.cc +++ b/onnxruntime/core/framework/op_kernel_info.cc @@ -1,25 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/mlvalue_name_idx_map.h" +#include "core/framework/fuse_nodes_funcs.h" #include "core/framework/op_kernel.h" #include "core/framework/op_kernel_info.h" -#include "core/framework/session_state.h" namespace onnxruntime { OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node, const KernelDef& kernel_def, const IExecutionProvider& execution_provider, - const SessionState& session_state) + const std::unordered_map& initialized_tensors, + const MLValueNameIdxMap& mlvalue_name_idx_map, + const FuncManager& funcs_mgr) : OpNodeProtoHelper(&proto_helper_context_), node_(node), kernel_def_(kernel_def), execution_provider_(&execution_provider), - proto_helper_context_(node), - session_state_(session_state) {} + initialized_tensors_(initialized_tensors), + mlvalue_name_idx_map_(mlvalue_name_idx_map), + funcs_mgr_(funcs_mgr), + proto_helper_context_(node) {} OpKernelInfo::OpKernelInfo(const OpKernelInfo& other) - : OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.session_state_) {} + : OpKernelInfo(other.node_, + other.kernel_def_, + *other.execution_provider_, + other.initialized_tensors_, + other.mlvalue_name_idx_map_, + other.funcs_mgr_) {} const OrtAllocatorInfo& OpKernelInfo::GetAllocatorInfo(int device_id, OrtMemType mem_type) const { AllocatorPtr alloc = GetAllocator(device_id, mem_type); @@ -49,13 +59,12 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_ } auto& input_arg_name = node_.InputDefs()[input_index]->Name(); int input_arg_index = -1; - if (!session_state_.GetMLValueNameIdxMap().GetIdx(input_arg_name, input_arg_index).IsOK()) { + if (!mlvalue_name_idx_map_.GetIdx(input_arg_name, input_arg_index).IsOK()) { return false; } - auto& initializers = session_state_.GetInitializedTensors(); - auto iter = initializers.find(input_arg_index); - if (initializers.end() == iter) { + auto iter = initialized_tensors_.find(input_arg_index); + if (initialized_tensors_.end() == iter) { return false; } if (!iter->second.IsTensor()) { @@ -67,7 +76,6 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_ } common::Status OpKernelInfo::GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const { - const auto& funcs_mgr = session_state_.GetFuncMgr(); - return funcs_mgr.GetFuncs(node_.Name(), compute, create, release); + return funcs_mgr_.GetFuncs(node_.Name(), compute, create, release); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index 84e95ec302..0ffa1e22ab 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ b/onnxruntime/core/framework/session_state_initializer.cc @@ -371,7 +371,10 @@ static common::Status CreateOpKernelInternal(const onnxruntime::Node& node, const SessionState& session_state, const KernelRegistryManager& custom_registry_manager, std::unique_ptr& op_kernel) { - return custom_registry_manager.CreateKernel(node, exec_provider, session_state, op_kernel); + return custom_registry_manager.CreateKernel(node, + exec_provider, + session_state, + op_kernel); } static common::Status CreateOpKernel(const onnxruntime::Node& node, diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index fdd6956f17..60a95891f7 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -194,7 +194,12 @@ class PlannerTest : public ::testing::Test { } void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) { - auto info = std::make_unique(*p_node, kernel_def, *execution_providers_.Get(*p_node), state_); + auto info = std::make_unique(*p_node, + kernel_def, + *execution_providers_.Get(*p_node), + state_.GetInitializedTensors(), + state_.GetMLValueNameIdxMap(), + state_.GetFuncMgr()); auto dummy = std::make_unique(*info); op_kernel_infos_.push_back(std::move(info)); state_.AddKernel(p_node->Index(), std::move(dummy)); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b586749f28..62c6932c58 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -53,7 +53,12 @@ TEST(SessionStateTest, AddGetKernelTest) { KernelDef kernel_def; CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}}; - OpKernelInfo p_info(node, kernel_def, execution_provider, s); + OpKernelInfo p_info(node, + kernel_def, + execution_provider, + s.GetInitializedTensors(), + s.GetMLValueNameIdxMap(), + s.GetFuncMgr()); unique_ptr p_kernel; p_kernel.reset(new TestOpKernel(p_info)); size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();