mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Make OpKernelInfo not depend on SessionState. (#442)
This commit is contained in:
parent
9faac70dae
commit
851e291f22
8 changed files with 64 additions and 24 deletions
|
|
@ -25,7 +25,9 @@ class KernelRegistry {
|
|||
// TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent
|
||||
Status CreateKernel(const onnxruntime::Node& node,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state,
|
||||
const std::unordered_map<int, MLValue>& initialized_tensors,
|
||||
const MLValueNameIdxMap& mlvalue_name_idx_map,
|
||||
const FuncManager& funcs_mgr,
|
||||
std::unique_ptr<OpKernel>& op_kernel) const;
|
||||
|
||||
// Check if an execution provider can create kernel for a node and return
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
class SessionState;
|
||||
|
||||
/**
|
||||
class MLValueNameIdxMap;
|
||||
class FuncManager;
|
||||
/**
|
||||
A very light-weight class, which works as an aggregated
|
||||
view of all data needed for constructing a Kernel instance.
|
||||
NOTE: it does not own/hold any objects.
|
||||
|
|
@ -25,7 +25,9 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
|||
explicit OpKernelInfo(const onnxruntime::Node& node,
|
||||
const KernelDef& kernel_def,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state);
|
||||
const std::unordered_map<int, MLValue>& initialized_tensors,
|
||||
const MLValueNameIdxMap& mlvalue_name_idx_map,
|
||||
const FuncManager& funcs_mgr);
|
||||
|
||||
OpKernelInfo(const OpKernelInfo& other);
|
||||
|
||||
|
|
@ -52,8 +54,10 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
|||
// For non cpu/cuda case, this pointer should be set so that function kernel
|
||||
// will delegate kernel compute call to <execution_provider> compute call.
|
||||
gsl::not_null<const ::onnxruntime::IExecutionProvider*> execution_provider_;
|
||||
const std::unordered_map<int, MLValue>& initialized_tensors_;
|
||||
const MLValueNameIdxMap& mlvalue_name_idx_map_;
|
||||
const FuncManager& funcs_mgr_;
|
||||
ProtoHelperNodeContext proto_helper_context_;
|
||||
const SessionState& session_state_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/framework/session_state.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
|
@ -260,7 +260,9 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
|
|||
|
||||
Status KernelRegistry::CreateKernel(const onnxruntime::Node& node,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state,
|
||||
const std::unordered_map<int, MLValue>& initialized_tensors,
|
||||
const MLValueNameIdxMap& mlvalue_name_idx_map,
|
||||
const FuncManager& funcs_mgr,
|
||||
/*out*/ std::unique_ptr<OpKernel>& op_kernel) const {
|
||||
const KernelCreateInfo* kernel_create_info = TryFindKernel(node, execution_provider.Type());
|
||||
|
||||
|
|
@ -268,7 +270,12 @@ Status KernelRegistry::CreateKernel(const onnxruntime::Node& node,
|
|||
return Status(ONNXRUNTIME, FAIL, "Failed to find kernel for " + node.OpType());
|
||||
}
|
||||
|
||||
OpKernelInfo kernel_info(node, *kernel_create_info->kernel_def, execution_provider, session_state);
|
||||
OpKernelInfo kernel_info(node,
|
||||
*kernel_create_info->kernel_def,
|
||||
execution_provider,
|
||||
initialized_tensors,
|
||||
mlvalue_name_idx_map,
|
||||
funcs_mgr);
|
||||
op_kernel.reset(kernel_create_info->kernel_create_func(kernel_info));
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/framework/customregistry.h"
|
||||
#include "core/framework/execution_providers.h"
|
||||
#include "core/framework/session_state.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -20,7 +21,12 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node,
|
|||
|
||||
Status status;
|
||||
for (auto& registry : kernel_registries_) {
|
||||
status = registry->CreateKernel(node, execution_provider, session_state, op_kernel);
|
||||
status = registry->CreateKernel(node,
|
||||
execution_provider,
|
||||
session_state.GetInitializedTensors(),
|
||||
session_state.GetMLValueNameIdxMap(),
|
||||
session_state.GetFuncMgr(),
|
||||
op_kernel);
|
||||
if (status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,25 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/mlvalue_name_idx_map.h"
|
||||
#include "core/framework/fuse_nodes_funcs.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/op_kernel_info.h"
|
||||
#include "core/framework/session_state.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
|
||||
const KernelDef& kernel_def,
|
||||
const IExecutionProvider& execution_provider,
|
||||
const SessionState& session_state)
|
||||
const std::unordered_map<int, MLValue>& initialized_tensors,
|
||||
const MLValueNameIdxMap& mlvalue_name_idx_map,
|
||||
const FuncManager& funcs_mgr)
|
||||
: OpNodeProtoHelper(&proto_helper_context_),
|
||||
node_(node),
|
||||
kernel_def_(kernel_def),
|
||||
execution_provider_(&execution_provider),
|
||||
proto_helper_context_(node),
|
||||
session_state_(session_state) {}
|
||||
initialized_tensors_(initialized_tensors),
|
||||
mlvalue_name_idx_map_(mlvalue_name_idx_map),
|
||||
funcs_mgr_(funcs_mgr),
|
||||
proto_helper_context_(node) {}
|
||||
|
||||
OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
|
||||
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.session_state_) {}
|
||||
: OpKernelInfo(other.node_,
|
||||
other.kernel_def_,
|
||||
*other.execution_provider_,
|
||||
other.initialized_tensors_,
|
||||
other.mlvalue_name_idx_map_,
|
||||
other.funcs_mgr_) {}
|
||||
|
||||
const OrtAllocatorInfo& OpKernelInfo::GetAllocatorInfo(int device_id, OrtMemType mem_type) const {
|
||||
AllocatorPtr alloc = GetAllocator(device_id, mem_type);
|
||||
|
|
@ -49,13 +59,12 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_
|
|||
}
|
||||
auto& input_arg_name = node_.InputDefs()[input_index]->Name();
|
||||
int input_arg_index = -1;
|
||||
if (!session_state_.GetMLValueNameIdxMap().GetIdx(input_arg_name, input_arg_index).IsOK()) {
|
||||
if (!mlvalue_name_idx_map_.GetIdx(input_arg_name, input_arg_index).IsOK()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& initializers = session_state_.GetInitializedTensors();
|
||||
auto iter = initializers.find(input_arg_index);
|
||||
if (initializers.end() == iter) {
|
||||
auto iter = initialized_tensors_.find(input_arg_index);
|
||||
if (initialized_tensors_.end() == iter) {
|
||||
return false;
|
||||
}
|
||||
if (!iter->second.IsTensor()) {
|
||||
|
|
@ -67,7 +76,6 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_
|
|||
}
|
||||
|
||||
common::Status OpKernelInfo::GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const {
|
||||
const auto& funcs_mgr = session_state_.GetFuncMgr();
|
||||
return funcs_mgr.GetFuncs(node_.Name(), compute, create, release);
|
||||
return funcs_mgr_.GetFuncs(node_.Name(), compute, create, release);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -371,7 +371,10 @@ static common::Status CreateOpKernelInternal(const onnxruntime::Node& node,
|
|||
const SessionState& session_state,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
std::unique_ptr<OpKernel>& op_kernel) {
|
||||
return custom_registry_manager.CreateKernel(node, exec_provider, session_state, op_kernel);
|
||||
return custom_registry_manager.CreateKernel(node,
|
||||
exec_provider,
|
||||
session_state,
|
||||
op_kernel);
|
||||
}
|
||||
|
||||
static common::Status CreateOpKernel(const onnxruntime::Node& node,
|
||||
|
|
|
|||
|
|
@ -194,7 +194,12 @@ class PlannerTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) {
|
||||
auto info = std::make_unique<OpKernelInfo>(*p_node, kernel_def, *execution_providers_.Get(*p_node), state_);
|
||||
auto info = std::make_unique<OpKernelInfo>(*p_node,
|
||||
kernel_def,
|
||||
*execution_providers_.Get(*p_node),
|
||||
state_.GetInitializedTensors(),
|
||||
state_.GetMLValueNameIdxMap(),
|
||||
state_.GetFuncMgr());
|
||||
auto dummy = std::make_unique<DummyOpKernel>(*info);
|
||||
op_kernel_infos_.push_back(std::move(info));
|
||||
state_.AddKernel(p_node->Index(), std::move(dummy));
|
||||
|
|
|
|||
|
|
@ -53,7 +53,12 @@ TEST(SessionStateTest, AddGetKernelTest) {
|
|||
KernelDef kernel_def;
|
||||
CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}};
|
||||
|
||||
OpKernelInfo p_info(node, kernel_def, execution_provider, s);
|
||||
OpKernelInfo p_info(node,
|
||||
kernel_def,
|
||||
execution_provider,
|
||||
s.GetInitializedTensors(),
|
||||
s.GetMLValueNameIdxMap(),
|
||||
s.GetFuncMgr());
|
||||
unique_ptr<TestOpKernel> p_kernel;
|
||||
p_kernel.reset(new TestOpKernel(p_info));
|
||||
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
|
||||
|
|
|
|||
Loading…
Reference in a new issue