onnxruntime/onnxruntime/core/framework/session_state.h
Ke Zhang 3bf0e364e2
Move CopyTensor out of IExecutionProvider interface. (#1268)
* add ortdevice class

* add data transfer manager for copying tensors.

* update

* add data transfer for gpu

* fix constexpr build break.

* update

* remove unnecessary header files.

* remove unnecessary header files.

* add dependency

* add dependency

* add dependency

* add dependency

* fix linux build break.

* update

* fix build break

* fix build break

* fix build break

* update

* update

* update c api.

* update to not use OrtCreateAllocatorInfo

* change to all eps .

* fix linux build break

* remove useless codes.

* update

* move datatransfermanager in session state

* update

* fix cuda build break.

* fix comments

* fix windows GPU build.

* fix comments

* fix build break

* fix comments

* fix test failure

* update

* fix comments

* fix onnx runtime server.

* update

* fix test failure.

* fix comments

* fix comment
2019-07-11 14:49:20 -07:00

244 lines
9.3 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <map>
#include <unordered_map>
#include <vector>
#include "gsl/gsl_util"
#include "core/platform/ort_mutex.h"
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/common/profiler.h"
#include "core/framework/allocation_planner.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/mem_pattern.h"
#include "core/framework/ml_value.h"
#include "core/common/callback.h"
#include "core/framework/ort_value_name_idx_map.h"
#include "core/framework/node_index_info.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/fuse_nodes_funcs.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
// Forward declarations of types that are only referred to by name, pointer or
// reference in this header.
// NOTE(review): some of these (e.g. ExecutionProviders, NodeIndexInfo) are presumably
// also fully declared by the headers included above - confirm before removing any.
class ExecutionProviders;
class KernelDef;
class OpKernel;
class NodeIndexInfo;
struct SequentialExecutionPlan;
struct MemoryPatternGroup;
/**
* SessionState should be modified by the inference session class only.
* It is supposed to be passed by const-ref only to all the executors.
* This class owns all the initializers.
*/
class SessionState {
 public:
  /**
   * \param execution_providers Stored by reference (not copied), so it must outlive this object.
   * \param enable_mem_pattern Switch for the memory pattern optimization; fixed for the
   *                           lifetime of this object.
   */
  SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern)
      : execution_providers_{execution_providers}, enable_mem_pattern_(enable_mem_pattern) {}

  /**
   * Invokes every deleter callback stored for the initialized tensors, passing the
   * callback its own registered parameter.
   */
  ~SessionState() {
    for (auto& kvp : deleter_for_initialized_tensors_) {
      kvp.second.f(kvp.second.param);
    }
  }

  // Graph viewer.
  void SetGraphViewer(std::unique_ptr<GraphViewer> graph_viewer);
  const GraphViewer* GetGraphViewer() const;

  // kernels
  // Get kernel for specified node.
  // It should be called right before graph execution only.
  const OpKernel* GetKernel(NodeIndex node_id) const;

  void AddKernel(NodeIndex node_id, std::unique_ptr<OpKernel> p_kernel);

  /// Execution providers passed to the constructor.
  const ExecutionProviders& GetExecutionProviders() const noexcept { return execution_providers_; }

  /// Read-only access to the OrtValue name/index map held by this object.
  const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_name_idx_map_; }
  /// Mutable access to the OrtValue name/index map held by this object.
  OrtValueNameIdxMap& GetOrtValueNameIdxMap() noexcept { return ort_value_name_idx_map_; }

  // initialized tensors
  /**
   * Adds an initialized tensor (weight) so that it can be used by the
   * execution frame to setup the appropriate OrtValue vectors.
   * This function will take a shallow copy of d if d is not NULL.
   * If 'constant' is true the tensor value cannot be overridden by an input at runtime.
   */
  Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant);

  /**
   * Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the
   * execution frame to setup the appropriate OrtValue vectors.
   * The lifetime of the returned OrtValues is limited by this SessionState object.
   */
  const std::unordered_map<int, OrtValue>& GetInitializedTensors() const;

  /**
   * Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant
   * and cannot be overridden at runtime.
   * The lifetime of the returned OrtValues is limited by this SessionState object.
   */
  const std::unordered_map<int, OrtValue>& GetConstantInitializedTensors() const;

  // execution plan
  void SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan);
  const SequentialExecutionPlan* GetExecutionPlan() const;

  /**
  Set the logger to use for this session.
  */
  SessionState& SetLogger(const logging::Logger& logger);

  /**
  Get the logger for this session.
  Falls back to returning Logging::LoggingManager::DefaultLogger if SetLogger has not been called.
  */
  const logging::Logger& Logger() const;

  /**
  Set the profiler for this session.
  */
  void SetProfiler(profiling::Profiler& profiler);

  /**
  Get the profiler for this session. It needs to be enabled via the InferenceSession to perform
  profiling actions.
  NOTE(review): the implementation is not visible in this header; presumably SetProfiler must
  have been called first, since the backing pointer starts out null - confirm.
  */
  profiling::Profiler& Profiler() const;

  /**
  Get cached memory pattern based on input shapes
  */
  const MemoryPatternGroup* GetMemoryPatternGroup(const std::vector<TensorShape>& input_shapes) const;

  /**
  Set the generated memory pattern for the given input shapes.
  Const as it's an internal cache update only.
  */
  Status UpdateMemoryPatternGroupCache(const std::vector<TensorShape>& input_shape,
                                       std::unique_ptr<MemoryPatternGroup> mem_patterns) const;

  /**
  Get enable memory pattern flag
  */
  bool GetEnableMemoryPattern() const;

  /// Entry stored in the input/output name -> node info maps below.
  struct NodeInfo {
    /**
     *
     * \param index0
     * \param p_node0 Nullable
     * \param kci0 Nullable
     */
    NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0)
        : index(index0),
          p_node(p_node0),
          kci(kci0) {
    }

    size_t index;
    // Nullable
    const onnxruntime::Node* p_node = nullptr;
    // Nullable
    const KernelCreateInfo* kci = nullptr;
  };

  using NameNodeInfoMapType = std::unordered_map<std::string, std::vector<NodeInfo>>;
  common::Status AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info);
  common::Status GetInputNodeInfo(const std::string& input_name, std::vector<NodeInfo>& node_info_vec) const;
  const NameNodeInfoMapType& GetInputNodeInfoMap() const;

  void AddOutputNameToNodeInfoMapping(const std::string& output_name, const NodeInfo& node_info);
  const NameNodeInfoMapType& GetOutputNodeInfoMap() const;

  /// Add a SessionState instance for executing a subgraph in a Node
  /// @param index Index of Node containing subgraph
  /// @param attribute_name Name of attribute containing the subgraph GraphProto
  /// @param session_state SessionState for subgraph execution
  void AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name,
                               std::unique_ptr<SessionState> session_state);

  /// Return SessionState for the given Node index and attribute name if found.
  const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const;

  SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name);

  /// Thread pool previously passed to SetThreadPool, or nullptr if none has been set.
  onnxruntime::concurrency::ThreadPool* GetThreadPool() const { return thread_pool_; }
  /// Stores the pointer as-is; nothing in this header takes ownership of the pool.
  void SetThreadPool(onnxruntime::concurrency::ThreadPool* p_pool) { thread_pool_ = p_pool; }

  /// Current value of the export-fused-dll flag (false until SetExportDllFlag changes it).
  bool ExportDll() const { return export_fused_dll_; }
  void SetExportDllFlag(bool flag) { export_fused_dll_ = flag; }

  const FuncManager& GetFuncMgr() const { return fused_funcs_mgr_; }
  FuncManager& GetMutableFuncMgr() { return fused_funcs_mgr_; }

  /// Precondition: SetDataTransferMgr must have been called with a non-null pointer.
  /// The stored pointer is dereferenced without a check.
  const DataTransferManager& GetDataTransferMgr() const { return *data_transfer_mgr_; }
  /// Stores the pointer as-is; nothing in this header takes ownership of the manager.
  void SetDataTransferMgr(const DataTransferManager* data_transfer_mgr) { data_transfer_mgr_ = data_transfer_mgr; }

  /// Mutable access to the buffers held in weights_buffers_.
  std::vector<BufferUniquePtr>& GetMutableWeightsBuffers() { return weights_buffers_; }

  void CalculateNodeIndexInfo();
  const NodeIndexInfo& GetNodeIndexInfo() const;

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);

  // cache of the constructed kernels to avoid spending construction
  // time per executor
  std::unordered_map<NodeIndex, std::unique_ptr<OpKernel>> session_kernels_;
  std::unique_ptr<GraphViewer> graph_viewer_;

  const ExecutionProviders& execution_providers_;  // owned by InferenceSession
  OrtValueNameIdxMap ort_value_name_idx_map_;

  // initialized tensors
  std::unordered_map<int, OrtValue> initialized_tensors_;  // key is ort_value_index
  // subset of initialized_tensors_ that are constant and cannot be overridden at runtime
  std::unordered_map<int, OrtValue> constant_initialized_tensors_;

  // This data structure is for uninitializing string tensors and
  // munmap memory region and close file descriptor
  std::unordered_map<int, OrtCallback> deleter_for_initialized_tensors_;
  std::vector<BufferUniquePtr> weights_buffers_;
  std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan_ = nullptr;

  const logging::Logger* logger_ = nullptr;
  // Explicitly null until SetProfiler is called, so the pointer never holds an
  // indeterminate value (consistent with logger_ and thread_pool_).
  profiling::Profiler* profiler_ = nullptr;

  // switch for enable memory pattern optimization or not.
  const bool enable_mem_pattern_;
  // lock for the mem_patterns_
  mutable OrtMutex mem_patterns_lock_;
  // cache for the generated mem_patterns. key is calculated based on input shapes.
  mutable std::map<int64_t, std::unique_ptr<MemoryPatternGroup>> mem_patterns_;

  NameNodeInfoMapType input_names_to_nodeinfo_mapping_;
  NameNodeInfoMapType output_names_to_nodeinfo_mapping_;

  // subgraph SessionState. entry for node containing subgraph, with value containing attribute:SessionState pair
  // as a node may contain multiple subgraphs (e.g. 'If' has one for both the 'then' and 'else' branches).
  using SubgraphSessionStateMap =
      std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
  SubgraphSessionStateMap subgraph_session_states_;

  // Not owned by anything visible in this header; set via SetThreadPool.
  onnxruntime::concurrency::ThreadPool* thread_pool_ = nullptr;

  bool export_fused_dll_ = false;
  FuncManager fused_funcs_mgr_;
  // Explicitly null until SetDataTransferMgr is called, so the pointer never holds an
  // indeterminate value. GetDataTransferMgr() dereferences it unconditionally.
  const DataTransferManager* data_transfer_mgr_ = nullptr;

  std::unique_ptr<NodeIndexInfo> node_index_info_;
  // NOTE(review): the meaning of the int key is not visible in this header - confirm against the .cc file.
  std::multimap<int, std::unique_ptr<FeedsFetchesManager>> cached_feeds_fetches_managers_;
};
} // namespace onnxruntime