mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
276 lines
11 KiB
C++
276 lines
11 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/framework/session_state.h"
|
|
|
|
#include <sstream>
|
|
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/framework/node_index_info.h"
|
|
#include "core/framework/op_kernel.h"
|
|
#include "core/framework/utils.h"
|
|
|
|
using namespace ::onnxruntime::common;
|
|
namespace onnxruntime {
|
|
|
|
// Takes ownership of `graph_viewer` and stores it in this SessionState.
// A null pointer is rejected via ORT_ENFORCE before anything is stored.
void SessionState::SetGraphViewer(std::unique_ptr<onnxruntime::GraphViewer> graph_viewer) {
  // Validate first so a previously stored viewer is never replaced by null.
  ORT_ENFORCE(nullptr != graph_viewer);

  graph_viewer_ = std::move(graph_viewer);
}
|
|
|
|
// Returns a non-owning pointer to the stored GraphViewer.
// The result is whatever graph_viewer_ currently holds, which may be null.
const GraphViewer* SessionState::GetGraphViewer() const {
  return graph_viewer_.get();
}
|
|
|
|
// Returns the kernel registered for `node_id`, or nullptr when no kernel has been added for that node.
// The returned pointer is non-owning; the kernel stays owned by session_kernels_.
const OpKernel* SessionState::GetKernel(NodeIndex node_id) const {
  // A single find() replaces the previous count() + find() pair, which performed the same lookup twice.
  const auto kernel_entry = session_kernels_.find(node_id);
  if (kernel_entry == session_kernels_.end()) {
    return nullptr;
  }

  return kernel_entry->second.get();
}
|
|
|
|
// Registers `p_kernel` as the kernel for `node_id`, taking ownership of it.
void SessionState::AddKernel(onnxruntime::NodeIndex node_id, std::unique_ptr<OpKernel> p_kernel) {
  // session_kernels_ is looked up by key with find() in GetKernel, so no pre-sizing is required here.
  // NOTE(review): assigning through operator[] presumably replaces any kernel already stored for
  // node_id - confirm against the container type declared in the header.
  auto& kernel_slot = session_kernels_[node_id];
  kernel_slot = std::move(p_kernel);
}
|
|
|
|
// Takes ownership of the sequential execution plan, replacing whatever plan was held before.
// No null check is performed on the incoming plan.
void SessionState::SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan) {
  p_seq_exec_plan_ = std::move(p_seq_exec_plan);
}
|
|
|
|
// Returns a non-owning pointer to the stored execution plan.
// The result is whatever p_seq_exec_plan_ currently holds, which may be null.
const SequentialExecutionPlan* SessionState::GetExecutionPlan() const {
  return p_seq_exec_plan_.get();
}
|
|
|
|
// Records `mlvalue` as the initialized tensor for `mlvalue_index`.
//
// mlvalue_index: must lie in [0, mlvalue_name_idx_map_.MaxIdx()]; enforced with ORT_ENFORCE.
// mlvalue:       value copied into initialized_tensors_.
// d:             optional callback; it is copied into deleter_for_initialized_tensors_ only when both
//                the pointer and its `f` member are non-null.
//
// Returns INVALID_ARGUMENT if a value was already stored for `mlvalue_index`; in that case neither the
// existing value nor the deleter table is modified. Returns OK otherwise.
Status SessionState::AddInitializedTensor(int mlvalue_index, const MLValue& mlvalue, const OrtCallback* d) {
  ORT_ENFORCE(mlvalue_index >= 0 && mlvalue_index <= mlvalue_name_idx_map_.MaxIdx());

  // insert() leaves an existing entry untouched and reports that through `.second`.
  const auto insert_result = initialized_tensors_.insert({mlvalue_index, mlvalue});
  const bool newly_inserted = insert_result.second;
  if (!newly_inserted) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated mlvalue index:", mlvalue_index,
                           ". Do you have duplicated calls to SessionState::AddInitializedTensor function?");
  }

  const bool has_deleter = (d != nullptr) && (d->f != nullptr);
  if (has_deleter) {
    deleter_for_initialized_tensors_[mlvalue_index] = *d;
  }

  return Status::OK();
}
|
|
|
|
// Returns a read-only reference to the map from mlvalue index to initialized tensor value.
const std::unordered_map<int, MLValue>& SessionState::GetInitializedTensors() const {
  return initialized_tensors_;
}
|
|
|
|
// Sets the logger used by this SessionState and returns *this so calls can be chained.
// Only the address of `logger` is stored; the logger itself is not copied.
SessionState& SessionState::SetLogger(const logging::Logger& logger) {
  logger_ = &logger;
  return *this;
}
|
|
|
|
// Returns the logger set via SetLogger, falling back to the LoggingManager default logger when
// none has been set.
const logging::Logger& SessionState::Logger() const {
  if (logger_ != nullptr) {
    return *logger_;
  }

  // DefaultLogger either throws or returns a valid logger.
  return logging::LoggingManager::DefaultLogger();
}
|
|
|
|
// Sets the profiler used by this SessionState.
// Only the address of `profiler` is stored; the profiler itself is not copied.
void SessionState::SetProfiler(profiling::Profiler& profiler) {
  profiler_ = &profiler;
}
|
|
|
|
// Returns the profiler stored in profiler_.
// NOTE(review): profiler_ is dereferenced without a null check, so presumably SetProfiler must have
// been called beforehand - confirm against the constructor/initialisation in the header.
::onnxruntime::profiling::Profiler& SessionState::Profiler() const {
  return *profiler_;
}
|
|
|
|
static int64_t CalculateMemoryPatternsKey(const std::vector<TensorShape>& shapes) {
|
|
int64_t key = 0;
|
|
for (auto& shape : shapes) {
|
|
for (auto dim : shape.GetDims()) key ^= dim;
|
|
}
|
|
return key;
|
|
}
|
|
|
|
// Looks up the cached memory pattern group for `input_shapes`.
// Returns a non-owning pointer to the cached group, or nullptr when nothing is cached under the key
// that CalculateMemoryPatternsKey derives from the shapes. The lookup runs under mem_patterns_lock_.
// Only the key is compared; the shapes themselves are not re-checked against the cached entry.
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<TensorShape>& input_shapes) const {
  std::lock_guard<OrtMutex> lock(mem_patterns_lock_);

  const auto cached = mem_patterns_.find(CalculateMemoryPatternsKey(input_shapes));
  if (cached != mem_patterns_.end()) {
    return cached->second.get();
  }

  return nullptr;
}
|
|
|
|
// Caches `mem_patterns` under the key derived from `input_shape`, unless an entry already exists for
// that key, in which case the existing entry is kept and `mem_patterns` is discarded when this function
// returns. Always returns OK.
//
// The key is computed before mem_patterns_lock_ is taken; only the map access runs under the lock.
Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<TensorShape>& input_shape,
                                                   std::unique_ptr<MemoryPatternGroup> mem_patterns) const {
  const int64_t key = CalculateMemoryPatternsKey(input_shape);

  std::lock_guard<OrtMutex> lock(mem_patterns_lock_);

  // First writer wins: never overwrite a pattern that is already cached.
  const bool already_cached = mem_patterns_.find(key) != mem_patterns_.end();
  if (!already_cached) {
    mem_patterns_[key] = std::move(mem_patterns);
  }

  return Status::OK();
}
|
|
|
|
|
|
// Records that `node_info` consumes the graph input `input_name`.
//
// In the future multiple nodes on different devices could share an input, but right now the logic in
// utils::CopyOneInputAcrossDevices only checks the first entry. Rather than silently adding entries
// that will be ignored, duplicates are accepted only when they require the same provider.
//
// Rules, in order:
//   1. First entry for the name: always stored.
//   2. New entry has index == max (implicit subgraph input or unused graph input): ignored.
//   3. Existing first entry has index == max: replaced by the new, explicit entry.
//   4. Both explicit: appended if the provider types match, otherwise NOT_IMPLEMENTED is returned.
common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info) {
  // index == max marks an entry for an implicit input to a subgraph or an unused graph input.
  constexpr size_t kImplicitOrUnusedIndex = std::numeric_limits<size_t>::max();

  auto& entries = input_names_to_nodeinfo_mapping_[input_name];

  if (entries.empty()) {
    entries.push_back(node_info);
    return Status::OK();
  }

  // Prefer the entry for explicit usage in this graph; implicit usage in a subgraph is handled by the
  // subgraph's SessionState, so an implicit entry never displaces an existing value.
  if (node_info.index == kImplicitOrUnusedIndex) {
    return Status::OK();
  }

  const auto& existing_entry = entries.front();
  if (existing_entry.index == kImplicitOrUnusedIndex) {
    // Replace the implicit placeholder with the entry for explicit usage in this graph.
    entries[0] = node_info;
    return Status::OK();
  }

  // Both entries are explicit. If the providers match, the new entry is added for completeness (it is
  // ignored by utils::CopyOneInputAcrossDevices though). If they differ, the model cannot be run.
  const auto& current_provider = utils::GetNodeInputProviderType(entries[0]);
  const auto& new_provider = utils::GetNodeInputProviderType(node_info);

  if (!(current_provider == new_provider)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                           "Using an input in multiple nodes on different devices is not supported currently. Input:",
                           input_name, " is used by node ", existing_entry.p_node->Name(), " (", current_provider,
                           ") and node ", node_info.p_node->Name(), " (", new_provider, ").");
  }

  entries.push_back(node_info);
  return Status::OK();
}
|
|
|
|
// Copies the NodeInfo entries recorded for `input_name` into `node_info_vec`.
//
// Returns FAIL (leaving `node_info_vec` untouched) when the name has no entry in
// input_names_to_nodeinfo_mapping_; returns OK otherwise.
common::Status SessionState::GetInputNodeInfo(const std::string& input_name,
                                              std::vector<NodeInfo>& node_info_vec) const {
  // A single find() replaces the previous count() + at() pair, which performed the same lookup twice.
  const auto entry = input_names_to_nodeinfo_mapping_.find(input_name);
  if (entry == input_names_to_nodeinfo_mapping_.end()) {
    return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping: " + input_name);
  }

  node_info_vec = entry->second;
  return Status::OK();
}
|
|
|
|
// Returns a read-only reference to the map from input name to the NodeInfo entries recorded for it.
const SessionState::NameNodeInfoMapType& SessionState::GetInputNodeInfoMap() const {
  const NameNodeInfoMapType& input_map = input_names_to_nodeinfo_mapping_;
  return input_map;
}
|
|
|
|
// Appends `node_info` to the entries recorded for the output `output_name`.
// Unlike the input-side mapping, no duplicate or provider checks are made here.
void SessionState::AddOutputNameToNodeInfoMapping(const std::string& output_name, const NodeInfo& node_info) {
  // operator[] creates the entry list for a name that has not been seen before.
  auto& entries_for_output = output_names_to_nodeinfo_mapping_[output_name];
  entries_for_output.push_back(node_info);
}
|
|
|
|
// Returns a read-only reference to the map from output name to the NodeInfo entries recorded for it.
const SessionState::NameNodeInfoMapType& SessionState::GetOutputNodeInfoMap() const {
  const NameNodeInfoMapType& output_map = output_names_to_nodeinfo_mapping_;
  return output_map;
}
|
|
|
|
// Registers the SessionState for the subgraph held in attribute `attribute_name` of node `index`,
// taking ownership of `session_state`.
//
// Registering the same (node, attribute) pair twice is treated as an internal logic error and trips
// ORT_ENFORCE.
void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name,
                                           std::unique_ptr<SessionState> session_state) {
  const auto node_entry = subgraph_session_states_.find(index);
  const bool node_already_known = node_entry != subgraph_session_states_.cend();

  // make sure this is new. internal logic error if it is not so using ORT_ENFORCE.
  if (node_already_known) {
    const auto& existing_entries = node_entry->second;
    ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
                " for attribute ", attribute_name);
  }

  // operator[] creates the per-node map on first use.
  auto& states_for_node = subgraph_session_states_[index];
  states_for_node.insert(std::make_pair(attribute_name, std::move(session_state)));
}
|
|
|
|
// Returns the SessionState registered for attribute `attribute_name` of node `index`, or nullptr when
// either the node or the attribute has no registered subgraph state. The pointer is non-owning.
SessionState* SessionState::GetMutableSubgraphSessionState(onnxruntime::NodeIndex index,
                                                           const std::string& attribute_name) {
  const auto node_entry = subgraph_session_states_.find(index);
  if (node_entry == subgraph_session_states_.cend()) {
    return nullptr;
  }

  const auto& attribute_state_map = node_entry->second;
  const auto subgraph_entry = attribute_state_map.find(attribute_name);
  if (subgraph_entry == attribute_state_map.cend()) {
    return nullptr;
  }

  return subgraph_entry->second.get();
}
|
|
|
|
// Const counterpart of GetMutableSubgraphSessionState: same lookup, but the result is returned as a
// pointer-to-const. Returns nullptr when no subgraph state is registered for (index, attribute_name).
const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex index,
                                                          const std::string& attribute_name) const {
  // The const_cast only serves to reuse the non-const lookup; the result is handed back as const.
  auto* mutable_this = const_cast<SessionState*>(this);
  return mutable_this->GetMutableSubgraphSessionState(index, attribute_name);
}
|
|
|
|
void SessionState::CalculateNodeIndexInfo() {
|
|
ORT_ENFORCE(graph_viewer_);
|
|
node_index_info_ = std::make_unique<NodeIndexInfo>(*graph_viewer_, mlvalue_name_idx_map_);
|
|
|
|
for (auto& node_to_map_pair : subgraph_session_states_) {
|
|
for (auto& attr_name_to_subgraph : node_to_map_pair.second) {
|
|
attr_name_to_subgraph.second->CalculateNodeIndexInfo();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns the NodeIndexInfo built by CalculateNodeIndexInfo.
// Trips ORT_ENFORCE when CalculateNodeIndexInfo has not been called yet (node_index_info_ is empty).
const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
  // The message previously named a non-existent "GetExecutionInfo"; it now names this function so the
  // failure points at the real call site.
  ORT_ENFORCE(node_index_info_, "CalculateNodeIndexInfo must be called prior to GetNodeIndexInfo.");
  return *node_index_info_;
}
|
|
|
|
// Use a cheap way of matching first. If we have multiple entries with this key, we will do the more
// expensive check of the individual feed/output names.
//
// The key packs the number of feeds into the bits above bit 16 and ORs in the number of outputs.
// Counts that do not fit in 16 bits can therefore alias other combinations; callers must still compare
// the actual names.
static int MakeFeedsFetchesManagerCacheKey(const std::vector<std::string>& feed_names,
                                           const std::vector<std::string>& output_names) {
  // Shift in unsigned arithmetic so a very large feed count cannot overflow a signed int.
  const auto num_feeds = static_cast<unsigned int>(feed_names.size());
  const auto num_outputs = static_cast<unsigned int>(output_names.size());

  return static_cast<int>((num_feeds << 16) | num_outputs);
}
|
|
|
|
// Finds the cached FeedsFetchesManager whose feed names and output names match the given lists exactly
// (same names, same order).
//
// Returns a non-owning pointer to the cached manager, or nullptr when there is no match.
const FeedsFetchesManager* SessionState::GetFeedsFetchesManager(const std::vector<std::string>& feed_names,
                                                                const std::vector<std::string>& output_names) const {
  const int key = MakeFeedsFetchesManagerCacheKey(feed_names, output_names);

  // The key is only a cheap pre-filter. Vector equality compares sizes before elements, which matters
  // because the key does not guarantee equal sizes once a count exceeds 16 bits; the previous
  // element-by-element loop indexed the cached vector without that check and could read out of bounds.
  const auto same_names = [](const std::vector<std::string>& requested, const std::vector<std::string>& cached) {
    return requested == cached;
  };

  // equal_range yields an empty range when the key is absent, so no separate count() is needed.
  const auto candidates = cached_feeds_fetches_managers_.equal_range(key);
  for (auto iter = candidates.first; iter != candidates.second; ++iter) {
    const auto& ffi = iter->second->GetFeedsFetchesInfo();

    if (same_names(feed_names, ffi.feed_names) && same_names(output_names, ffi.output_names)) {
      return iter->second.get();
    }
  }

  return nullptr;
}
|
|
|
|
// Stores `manager` in the cache under the key derived from the feed and output name lists, taking
// ownership of it. Always returns OK.
//
// Inserting a manager for a feeds/outputs combination that is already cached trips ORT_ENFORCE.
Status SessionState::CacheFeedsFetchesManager(const std::vector<std::string>& feed_names,
                                              const std::vector<std::string>& output_names,
                                              std::unique_ptr<FeedsFetchesManager> manager) {
  const int key = MakeFeedsFetchesManagerCacheKey(feed_names, output_names);

  // Only pay for the full name comparison when something is already stored under this key.
  const bool key_in_use = cached_feeds_fetches_managers_.count(key) > 0;
  if (key_in_use) {
    // be paranoid and make sure we're not attempting to insert a duplicate entry.
    // if so it would imply that there is probably a concurrency issue in InferenceSession::Run.
    ORT_ENFORCE(GetFeedsFetchesManager(feed_names, output_names) == nullptr, "Existing FeedsFetchesManager found.");
  }

  cached_feeds_fetches_managers_.emplace(key, std::move(manager));

  return Status::OK();
}
|
|
|
|
} // namespace onnxruntime
|