Make OpKernelInfo not depend on SessionState. (#442)

This commit is contained in:
Weixing Zhang 2019-02-05 22:38:50 -08:00 committed by GitHub
parent 9faac70dae
commit 851e291f22
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 24 deletions

View file

@ -25,7 +25,9 @@ class KernelRegistry {
// TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent
Status CreateKernel(const onnxruntime::Node& node,
const IExecutionProvider& execution_provider,
const SessionState& session_state,
const std::unordered_map<int, MLValue>& initialized_tensors,
const MLValueNameIdxMap& mlvalue_name_idx_map,
const FuncManager& funcs_mgr,
std::unique_ptr<OpKernel>& op_kernel) const;
// Check if an execution provider can create kernel for a node and return

View file

@ -13,9 +13,9 @@
namespace onnxruntime {
class SessionState;
/**
class MLValueNameIdxMap;
class FuncManager;
/**
A very lightweight class that serves as an aggregated
view of all the data needed to construct a Kernel instance.
NOTE: it does not own or hold any objects.
@ -25,7 +25,9 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
explicit OpKernelInfo(const onnxruntime::Node& node,
const KernelDef& kernel_def,
const IExecutionProvider& execution_provider,
const SessionState& session_state);
const std::unordered_map<int, MLValue>& initialized_tensors,
const MLValueNameIdxMap& mlvalue_name_idx_map,
const FuncManager& funcs_mgr);
OpKernelInfo(const OpKernelInfo& other);
@ -52,8 +54,10 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
// For the non-CPU/CUDA case, this pointer should be set so that the function kernel
// delegates its compute call to the <execution_provider>'s compute call.
gsl::not_null<const ::onnxruntime::IExecutionProvider*> execution_provider_;
const std::unordered_map<int, MLValue>& initialized_tensors_;
const MLValueNameIdxMap& mlvalue_name_idx_map_;
const FuncManager& funcs_mgr_;
ProtoHelperNodeContext proto_helper_context_;
const SessionState& session_state_;
};
} // namespace onnxruntime

View file

@ -3,8 +3,8 @@
#include <memory>
#include <unordered_map>
#include "core/framework/kernel_registry.h"
#include "core/framework/session_state.h"
using namespace ::onnxruntime::common;
namespace onnxruntime {
@ -260,7 +260,9 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
Status KernelRegistry::CreateKernel(const onnxruntime::Node& node,
const IExecutionProvider& execution_provider,
const SessionState& session_state,
const std::unordered_map<int, MLValue>& initialized_tensors,
const MLValueNameIdxMap& mlvalue_name_idx_map,
const FuncManager& funcs_mgr,
/*out*/ std::unique_ptr<OpKernel>& op_kernel) const {
const KernelCreateInfo* kernel_create_info = TryFindKernel(node, execution_provider.Type());
@ -268,7 +270,12 @@ Status KernelRegistry::CreateKernel(const onnxruntime::Node& node,
return Status(ONNXRUNTIME, FAIL, "Failed to find kernel for " + node.OpType());
}
OpKernelInfo kernel_info(node, *kernel_create_info->kernel_def, execution_provider, session_state);
OpKernelInfo kernel_info(node,
*kernel_create_info->kernel_def,
execution_provider,
initialized_tensors,
mlvalue_name_idx_map,
funcs_mgr);
op_kernel.reset(kernel_create_info->kernel_create_func(kernel_info));
return Status::OK();
}

View file

@ -5,6 +5,7 @@
#include "core/framework/kernel_registry.h"
#include "core/framework/customregistry.h"
#include "core/framework/execution_providers.h"
#include "core/framework/session_state.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
@ -20,7 +21,12 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node,
Status status;
for (auto& registry : kernel_registries_) {
status = registry->CreateKernel(node, execution_provider, session_state, op_kernel);
status = registry->CreateKernel(node,
execution_provider,
session_state.GetInitializedTensors(),
session_state.GetMLValueNameIdxMap(),
session_state.GetFuncMgr(),
op_kernel);
if (status.IsOK()) {
return status;
}

View file

@ -1,25 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/mlvalue_name_idx_map.h"
#include "core/framework/fuse_nodes_funcs.h"
#include "core/framework/op_kernel.h"
#include "core/framework/op_kernel_info.h"
#include "core/framework/session_state.h"
namespace onnxruntime {
OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
const KernelDef& kernel_def,
const IExecutionProvider& execution_provider,
const SessionState& session_state)
const std::unordered_map<int, MLValue>& initialized_tensors,
const MLValueNameIdxMap& mlvalue_name_idx_map,
const FuncManager& funcs_mgr)
: OpNodeProtoHelper(&proto_helper_context_),
node_(node),
kernel_def_(kernel_def),
execution_provider_(&execution_provider),
proto_helper_context_(node),
session_state_(session_state) {}
initialized_tensors_(initialized_tensors),
mlvalue_name_idx_map_(mlvalue_name_idx_map),
funcs_mgr_(funcs_mgr),
proto_helper_context_(node) {}
OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.session_state_) {}
: OpKernelInfo(other.node_,
other.kernel_def_,
*other.execution_provider_,
other.initialized_tensors_,
other.mlvalue_name_idx_map_,
other.funcs_mgr_) {}
const OrtAllocatorInfo& OpKernelInfo::GetAllocatorInfo(int device_id, OrtMemType mem_type) const {
AllocatorPtr alloc = GetAllocator(device_id, mem_type);
@ -49,13 +59,12 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_
}
auto& input_arg_name = node_.InputDefs()[input_index]->Name();
int input_arg_index = -1;
if (!session_state_.GetMLValueNameIdxMap().GetIdx(input_arg_name, input_arg_index).IsOK()) {
if (!mlvalue_name_idx_map_.GetIdx(input_arg_name, input_arg_index).IsOK()) {
return false;
}
auto& initializers = session_state_.GetInitializedTensors();
auto iter = initializers.find(input_arg_index);
if (initializers.end() == iter) {
auto iter = initialized_tensors_.find(input_arg_index);
if (initialized_tensors_.end() == iter) {
return false;
}
if (!iter->second.IsTensor()) {
@ -67,7 +76,6 @@ bool OpKernelInfo::TryGetConstantInput(int input_index, const Tensor** constant_
}
// Looks up the fused-node functions for this kernel's node.
//
// Forwards `compute`, `create` and `release` to the function manager's
// GetFuncs(), keyed by node_.Name(), and returns the Status of that call.
//
// NOTE(review): this span comes from a diff view without +/- markers. The two
// lines that go through session_state_ and the single line that uses
// funcs_mgr_ appear to be the before/after forms of the same lookup, not
// sequential statements — confirm against the actual file before editing.
common::Status OpKernelInfo::GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const {
const auto& funcs_mgr = session_state_.GetFuncMgr();
return funcs_mgr.GetFuncs(node_.Name(), compute, create, release);
return funcs_mgr_.GetFuncs(node_.Name(), compute, create, release);
}
} // namespace onnxruntime

View file

@ -371,7 +371,10 @@ static common::Status CreateOpKernelInternal(const onnxruntime::Node& node,
const SessionState& session_state,
const KernelRegistryManager& custom_registry_manager,
std::unique_ptr<OpKernel>& op_kernel) {
return custom_registry_manager.CreateKernel(node, exec_provider, session_state, op_kernel);
return custom_registry_manager.CreateKernel(node,
exec_provider,
session_state,
op_kernel);
}
static common::Status CreateOpKernel(const onnxruntime::Node& node,

View file

@ -194,7 +194,12 @@ class PlannerTest : public ::testing::Test {
}
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) {
auto info = std::make_unique<OpKernelInfo>(*p_node, kernel_def, *execution_providers_.Get(*p_node), state_);
auto info = std::make_unique<OpKernelInfo>(*p_node,
kernel_def,
*execution_providers_.Get(*p_node),
state_.GetInitializedTensors(),
state_.GetMLValueNameIdxMap(),
state_.GetFuncMgr());
auto dummy = std::make_unique<DummyOpKernel>(*info);
op_kernel_infos_.push_back(std::move(info));
state_.AddKernel(p_node->Index(), std::move(dummy));

View file

@ -53,7 +53,12 @@ TEST(SessionStateTest, AddGetKernelTest) {
KernelDef kernel_def;
CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}};
OpKernelInfo p_info(node, kernel_def, execution_provider, s);
OpKernelInfo p_info(node,
kernel_def,
execution_provider,
s.GetInitializedTensors(),
s.GetMLValueNameIdxMap(),
s.GetFuncMgr());
unique_ptr<TestOpKernel> p_kernel;
p_kernel.reset(new TestOpKernel(p_info));
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();