mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Optimize kernel index (#1672)
This commit is contained in:
parent
a818740d91
commit
4de0aa8049
12 changed files with 214 additions and 483 deletions
|
|
@ -526,7 +526,7 @@ install(TARGETS onnx_test_runner
|
|||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
|
||||
if(onnxruntime_BUILD_BENCHMARKS)
|
||||
add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc ${TEST_SRC_DIR}/onnx/microbenchmark/model_init.cc)
|
||||
add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc)
|
||||
target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} benchmark)
|
||||
onnxruntime_add_include_to_target(onnxruntime_benchmark gsl)
|
||||
if(WIN32)
|
||||
|
|
|
|||
|
|
@ -11,23 +11,96 @@
|
|||
#include "core/framework/utils.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Stores the supplied GraphViewer in this SessionState, taking ownership of it.
// A null viewer trips ORT_ENFORCE.
void SessionState::SetGraphViewer(
    std::unique_ptr<onnxruntime::GraphViewer> graph_viewer) {
  ORT_ENFORCE(nullptr != graph_viewer);

  graph_viewer_ = std::move(graph_viewer);
}
|
||||
|
||||
const GraphViewer* SessionState::GetGraphViewer() const { return graph_viewer_.get(); }
|
||||
// Creates a GraphViewer over `graph`, stores it in graph_viewer_, and populates
// ort_value_name_idx_map_ with an entry for every value name the viewer exposes:
// graph inputs (including initializers), each node's explicit inputs, implicit inputs
// and outputs, and finally the graph outputs.
//
// Returns Status::OK() on every path visible here.
Status SessionState::SetGraph(const Graph& graph) {
  graph_viewer_ = std::make_unique<onnxruntime::GraphViewer>(graph);
  auto& logger = Logger();

  // use graph_viewer_ to initialize ort_value_name_idx_map_
  LOGS(logger, INFO) << "SaveMLValueNameIndexMapping";
  int idx = 0;

  // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry
  for (const auto* input_def : graph_viewer_->GetInputsIncludingInitializers()) {
    idx = ort_value_name_idx_map_.Add(input_def->Name());
    VLOGS(logger, 1) << "Added graph_viewer_ input with name: " << input_def->Name()
                     << " to OrtValueIndex with index: " << idx;
  }

  for (auto& node : graph_viewer_->Nodes()) {
    // build the OrtValue->index map
    for (const auto* input_def : node.InputDefs()) {
      if (input_def->Exists()) {
        idx = ort_value_name_idx_map_.Add(input_def->Name());
        VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }

    for (const auto* input_def : node.ImplicitInputDefs()) {
      if (input_def->Exists()) {
        idx = ort_value_name_idx_map_.Add(input_def->Name());
        VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }

    for (const auto* output_def : node.OutputDefs()) {
      if (output_def->Exists()) {
        // Capture the returned index; without the assignment the log line below
        // reports whatever index the previously processed argument received.
        idx = ort_value_name_idx_map_.Add(output_def->Name());
        VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }
  }

  // allocate OrtValue for graph outputs when coming from initializers
  for (const auto& output : graph_viewer_->GetOutputs()) {
    if (output->Exists()) {
      idx = ort_value_name_idx_map_.Add(output->Name());
      VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx;
    }
  }

  LOGS(logger, INFO) << "Done saving OrtValue mappings.";
  return Status::OK();
}
|
||||
|
||||
void SessionState::AddKernel(onnxruntime::NodeIndex node_id, std::unique_ptr<OpKernel> p_kernel) {
|
||||
// assumes vector is already resize()'ed to the number of nodes in the graph
|
||||
session_kernels_[node_id] = std::move(p_kernel);
|
||||
// Creates one OpKernel per node of graph_viewer_ and stores it in session_kernels_,
// indexed by the node's Index(). Afterwards rebuilds node_index_info_ from the viewer
// and ort_value_name_idx_map_.
//
// \param custom_registry_manager registry manager used to instantiate each kernel.
// \return Status::OK() on success; FAIL when a node has no usable execution provider;
//         otherwise the failing CreateKernel status with the node name prepended.
//
// graph_viewer_ must already be set (see SetGraph); this is enforced.
Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) {
  ORT_ENFORCE(graph_viewer_, "SetGraph must be called prior to CreateKernels.");
  const GraphNodes& nodes = graph_viewer_->Nodes();

  // session_kernels_ holds raw owning pointers, so clear() alone would leak anything
  // left from an earlier call. Free the old kernels first, and do it even when the
  // graph has no nodes so that no stale kernel stays reachable.
  for (auto* stale_kernel : session_kernels_) {
    delete stale_kernel;
  }
  session_kernels_.clear();

  if (!nodes.empty()) {
    // Node indices are not assumed to be dense: size the table by the largest index.
    size_t max_nodeid = 0;
    for (auto& node : nodes) {
      max_nodeid = std::max(max_nodeid, node.Index());
    }
    session_kernels_.resize(max_nodeid + 1, nullptr);

    for (auto& node : nodes) {
      // construct and save the kernels
      std::unique_ptr<OpKernel> op_kernel;
      onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType();

      const IExecutionProvider* exec_provider = nullptr;
      if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(),
                               " as there's no execution provider allocated.");
      }

      common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel);
      if (!status.IsOK()) {
        return common::Status(
            status.Category(), status.Code(),
            MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage()));
      }

      assert(session_kernels_[node.Index()] == nullptr);
      // the vector was resized above to cover every node index; ownership moves to the table
      session_kernels_[node.Index()] = op_kernel.release();
    }
  }

  node_index_info_ = std::make_unique<NodeIndexInfo>(*graph_viewer_, ort_value_name_idx_map_);
  return Status::OK();
}
|
||||
|
||||
void SessionState::SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan) {
|
||||
|
|
@ -38,7 +111,6 @@ const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p
|
|||
|
||||
Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d,
|
||||
bool constant) {
|
||||
ORT_ENFORCE(ort_value_index >= 0 && ort_value_index <= ort_value_name_idx_map_.MaxIdx());
|
||||
auto p = initialized_tensors_.insert({ort_value_index, ort_value});
|
||||
if (!p.second)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated ort_value index:", ort_value_index,
|
||||
|
|
@ -55,9 +127,7 @@ Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& o
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const std::unordered_map<int, OrtValue>& SessionState::GetInitializedTensors() const {
|
||||
return initialized_tensors_;
|
||||
}
|
||||
const std::unordered_map<int, OrtValue>& SessionState::GetInitializedTensors() const { return initialized_tensors_; }
|
||||
|
||||
const std::unordered_map<int, OrtValue>& SessionState::GetConstantInitializedTensors() const {
|
||||
return constant_initialized_tensors_;
|
||||
|
|
@ -86,7 +156,8 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
|
|||
return key;
|
||||
}
|
||||
|
||||
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const {
|
||||
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(
|
||||
const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const {
|
||||
int64_t key = CalculateMemoryPatternsKey(input_shapes);
|
||||
|
||||
std::lock_guard<OrtMutex> lock(mem_patterns_lock_);
|
||||
|
|
@ -96,8 +167,9 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<
|
|||
return it->second.get();
|
||||
}
|
||||
|
||||
Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes,
|
||||
std::unique_ptr<MemoryPatternGroup> mem_patterns) const {
|
||||
Status SessionState::UpdateMemoryPatternGroupCache(
|
||||
const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes,
|
||||
std::unique_ptr<MemoryPatternGroup> mem_patterns) const {
|
||||
int64_t key = CalculateMemoryPatternsKey(input_shapes);
|
||||
|
||||
std::lock_guard<OrtMutex> lock(mem_patterns_lock_);
|
||||
|
|
@ -109,9 +181,7 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<std::refere
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool SessionState::GetEnableMemoryPattern() const {
|
||||
return enable_mem_pattern_;
|
||||
}
|
||||
bool SessionState::GetEnableMemoryPattern() const { return enable_mem_pattern_; }
|
||||
|
||||
common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info) {
|
||||
// in the future we could support multiple nodes on different devices using an input, however right now
|
||||
|
|
@ -144,10 +214,11 @@ common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& in
|
|||
if (current_provider == new_provider) {
|
||||
entries.push_back(node_info);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"Using an input in multiple nodes on different devices is not supported currently. Input:",
|
||||
input_name, " is used by node ", existing_entry.p_node->Name(), " (", current_provider,
|
||||
") and node ", node_info.p_node->Name(), " (", new_provider, ").");
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"Using an input in multiple nodes on different devices is not supported currently. Input:", input_name,
|
||||
" is used by node ", existing_entry.p_node->Name(), " (", current_provider, ") and node ",
|
||||
node_info.p_node->Name(), " (", new_provider, ").");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -178,16 +249,15 @@ const SessionState::NameNodeInfoMapType& SessionState::GetOutputNodeInfoMap() co
|
|||
return output_names_to_nodeinfo_mapping_;
|
||||
}
|
||||
|
||||
void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index,
|
||||
const std::string& attribute_name,
|
||||
void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name,
|
||||
std::unique_ptr<SessionState> session_state) {
|
||||
auto entry = subgraph_session_states_.find(index);
|
||||
|
||||
// make sure this is new. internal logic error if it is not so using ORT_ENFORCE.
|
||||
if (entry != subgraph_session_states_.cend()) {
|
||||
const auto& existing_entries = entry->second;
|
||||
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(),
|
||||
"Entry exists in node ", index, " for attribute ", attribute_name);
|
||||
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
|
||||
" for attribute ", attribute_name);
|
||||
}
|
||||
|
||||
subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state)));
|
||||
|
|
@ -215,19 +285,8 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex
|
|||
return const_cast<SessionState*>(this)->GetMutableSubgraphSessionState(index, attribute_name);
|
||||
}
|
||||
|
||||
void SessionState::CalculateNodeIndexInfo() {
|
||||
ORT_ENFORCE(graph_viewer_);
|
||||
node_index_info_ = std::make_unique<NodeIndexInfo>(*graph_viewer_, ort_value_name_idx_map_);
|
||||
|
||||
for (auto& node_to_map_pair : subgraph_session_states_) {
|
||||
for (auto& attr_name_to_subgraph : node_to_map_pair.second) {
|
||||
attr_name_to_subgraph.second->CalculateNodeIndexInfo();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
|
||||
ORT_ENFORCE(node_index_info_, "CalculateNodeIndexInfo must be called prior to GetExecutionInfo.");
|
||||
ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo.");
|
||||
return *node_index_info_;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -40,33 +40,41 @@ struct MemoryPatternGroup;
|
|||
* SessionState should be modified by the inference session class only.
|
||||
* It is supposed to be passed by const-ref only to all the executors.
|
||||
* This class owns all the initializers.
|
||||
* Brief usage:
|
||||
* SessionState s(...);
|
||||
* for(...) s.AddInitializedTensor(...);
|
||||
* s.SetGraphAndCreateKernels(...);
|
||||
* Then you can use:
|
||||
* s.GetKernel(...);
|
||||
*/
|
||||
class SessionState {
|
||||
public:
|
||||
SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern, concurrency::ThreadPool* thread_pool)
|
||||
SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern,
|
||||
concurrency::ThreadPool* thread_pool)
|
||||
: execution_providers_{execution_providers}, enable_mem_pattern_(enable_mem_pattern), thread_pool_(thread_pool) {}
|
||||
|
||||
// Releases everything this object holds through raw handles.
~SessionState() {
  // Kernels are stored as raw pointers, so each one is deleted explicitly.
  for (auto* kernel : session_kernels_) {
    delete kernel;
  }

  // Invoke the stored callback of every registered initialized-tensor deleter.
  for (auto& deleter_entry : deleter_for_initialized_tensors_) {
    auto& callback = deleter_entry.second;
    callback.f(callback.param);
  }
}
|
||||
|
||||
// Graph viewer.
|
||||
void SetGraphViewer(std::unique_ptr<GraphViewer> graph_viewer);
|
||||
const GraphViewer* GetGraphViewer() const;
|
||||
|
||||
// kernels
|
||||
// Get kernel for specified node.
|
||||
// It should be called right before graph execution only.
|
||||
const OpKernel* GetKernel(NodeIndex node_id) const;
|
||||
|
||||
void AddKernel(NodeIndex node_id, std::unique_ptr<OpKernel> p_kernel);
|
||||
// Returns the entry stored at `node_id` in session_kernels_, or nullptr when
// `node_id` lies outside the table.
const OpKernel* GetKernel(size_t node_id) const {
  if (node_id >= session_kernels_.size()) {
    return nullptr;
  }
  return session_kernels_[node_id];
}
|
||||
|
||||
const ExecutionProviders& GetExecutionProviders() const noexcept { return execution_providers_; }
|
||||
|
||||
const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_name_idx_map_; }
|
||||
OrtValueNameIdxMap& GetOrtValueNameIdxMap() noexcept { return ort_value_name_idx_map_; }
|
||||
|
||||
// initialized tensors
|
||||
/**
|
||||
|
|
@ -77,6 +85,12 @@ class SessionState {
|
|||
*/
|
||||
Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant);
|
||||
|
||||
Status SetGraph(const Graph& graph);
|
||||
Status CreateKernels(const KernelRegistryManager& custom_registry_manager);
|
||||
// Convenience wrapper: runs SetGraph(graph) and, if that succeeds, returns the result
// of CreateKernels(custom_registry_manager). A SetGraph failure is returned as-is via
// ORT_RETURN_IF_ERROR.
Status SetGraphAndCreateKernels(const Graph& graph,
                                const KernelRegistryManager& custom_registry_manager) {
  ORT_RETURN_IF_ERROR(SetGraph(graph));

  return CreateKernels(custom_registry_manager);
}
|
||||
/**
|
||||
* Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the
|
||||
* execution frame to setup the appropriate OrtValue vectors.
|
||||
|
|
@ -85,8 +99,8 @@ class SessionState {
|
|||
const std::unordered_map<int, OrtValue>& GetInitializedTensors() const;
|
||||
|
||||
/**
|
||||
* Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant
|
||||
* and cannot be overridden at runtime.
|
||||
* Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant
|
||||
* and cannot be overridden at runtime.
|
||||
* The lifetime of returned OrtValues are limited by this SessionState object.
|
||||
*/
|
||||
const std::unordered_map<int, OrtValue>& GetConstantInitializedTensors() const;
|
||||
|
|
@ -96,12 +110,12 @@ class SessionState {
|
|||
const SequentialExecutionPlan* GetExecutionPlan() const;
|
||||
|
||||
/**
|
||||
Set the logger to use for this session.
|
||||
Set the logger to use for this session.
|
||||
*/
|
||||
SessionState& SetLogger(const logging::Logger& logger);
|
||||
|
||||
/**
|
||||
Get the logger for this session.
|
||||
Get the logger for this session.
|
||||
Falls back to returning Logging::LoggingManager::DefaultLogger if SetLogger has not been called.
|
||||
*/
|
||||
const logging::Logger& Logger() const;
|
||||
|
|
@ -120,10 +134,11 @@ class SessionState {
|
|||
/**
|
||||
Get cached memory pattern based on input shapes
|
||||
*/
|
||||
const MemoryPatternGroup* GetMemoryPatternGroup(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const;
|
||||
const MemoryPatternGroup* GetMemoryPatternGroup(
|
||||
const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const;
|
||||
|
||||
/**
|
||||
Set generated memory pattern with a given input shapes.
|
||||
Set generated memory pattern with a given input shapes.
|
||||
Const as it's an internal cache update only.
|
||||
*/
|
||||
Status UpdateMemoryPatternGroupCache(const std::vector<std::reference_wrapper<const TensorShape>>& input_shape,
|
||||
|
|
@ -142,10 +157,7 @@ class SessionState {
|
|||
* \param kci0 Nullable
|
||||
*/
|
||||
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0)
|
||||
: index(index0),
|
||||
p_node(p_node0),
|
||||
kci(kci0),
|
||||
device(&device0) {}
|
||||
: index(index0), p_node(p_node0), kci(kci0), device(&device0) {}
|
||||
|
||||
size_t index;
|
||||
// Nullable
|
||||
|
|
@ -187,7 +199,6 @@ class SessionState {
|
|||
void SetDataTransferMgr(const DataTransferManager* data_transfer_mgr) { data_transfer_mgr_ = data_transfer_mgr; }
|
||||
|
||||
std::vector<BufferUniquePtr>& GetMutableWeightsBuffers() { return weights_buffers_; }
|
||||
void CalculateNodeIndexInfo();
|
||||
const NodeIndexInfo& GetNodeIndexInfo() const;
|
||||
|
||||
private:
|
||||
|
|
@ -195,7 +206,7 @@ class SessionState {
|
|||
|
||||
// cache of the constructed kernels to avoid spending construction
|
||||
// time per executor
|
||||
std::unordered_map<NodeIndex, std::unique_ptr<OpKernel>> session_kernels_;
|
||||
std::vector<OpKernel*> session_kernels_;
|
||||
std::unique_ptr<GraphViewer> graph_viewer_;
|
||||
|
||||
const ExecutionProviders& execution_providers_; // owned by InferenceSession
|
||||
|
|
@ -231,7 +242,7 @@ class SessionState {
|
|||
std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
|
||||
SubgraphSessionStateMap subgraph_session_states_;
|
||||
|
||||
//It could be NULL
|
||||
// It could be NULL
|
||||
concurrency::ThreadPool* const thread_pool_;
|
||||
|
||||
bool export_fused_dll_ = false;
|
||||
|
|
|
|||
|
|
@ -27,9 +27,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
static common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer,
|
||||
OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const logging::Logger& logger);
|
||||
|
||||
// T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status'
|
||||
template <typename T>
|
||||
|
|
@ -40,11 +37,6 @@ static common::Status SaveInitializedTensors(const Env& env, const std::basic_st
|
|||
const logging::Logger& logger,
|
||||
const DataTransferManager& data_transfer_mgr);
|
||||
|
||||
static common::Status SaveKernels(const ExecutionProviders& execution_providers,
|
||||
SessionState& session_state,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
const logging::Logger& logger);
|
||||
|
||||
static common::Status SaveInputOutputNamesToNodeMapping(
|
||||
const onnxruntime::Graph& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
|
|
@ -68,11 +60,11 @@ common::Status SessionStateInitializer::CreatePlan(
|
|||
const Node* parent_node,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
|
||||
bool enable_sequential_execution) {
|
||||
auto graph_viewer = std::make_unique<onnxruntime::GraphViewer>(graph_);
|
||||
session_state_.SetGraph(graph_);
|
||||
const GraphViewer* graph_viewer = session_state_.GetGraphViewer();
|
||||
|
||||
// populate the SessionState OrtValueNameIdxMap
|
||||
auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap();
|
||||
ORT_RETURN_IF_ERROR(SaveMLValueNameIndexMapping(*graph_viewer, ort_value_name_idx_map, logger_));
|
||||
const auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap();
|
||||
|
||||
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
|
||||
std::vector<const NodeArg*> valid_outer_scope_node_args;
|
||||
|
|
@ -92,17 +84,10 @@ common::Status SessionStateInitializer::CreatePlan(
|
|||
execution_providers_, kernel_registry_manager_,
|
||||
ort_value_name_idx_map, context, exec_plan));
|
||||
session_state_.SetExecutionPlan(std::move(exec_plan));
|
||||
session_state_.SetGraphViewer(std::move(graph_viewer));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status SessionStateInitializer::InitializeAndSave(
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs) {
|
||||
const auto* exec_plan_ptr = session_state_.GetExecutionPlan();
|
||||
ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first.");
|
||||
|
||||
const auto& ort_value_name_idx_map{session_state_.GetOrtValueNameIdxMap()};
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator_(ITensorAllocator::Create(
|
||||
enable_mem_pattern_, *exec_plan_ptr, execution_providers_, session_state_.GetMutableWeightsBuffers()));
|
||||
|
||||
|
|
@ -119,64 +104,12 @@ common::Status SessionStateInitializer::InitializeAndSave(
|
|||
// TODO: make it better
|
||||
graph_.CleanAllInitializedTensors();
|
||||
|
||||
ORT_RETURN_IF_ERROR(SaveKernels(execution_providers_, session_state_, kernel_registry_manager_, logger_));
|
||||
ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_,
|
||||
implicit_inputs));
|
||||
|
||||
ORT_RETURN_IF_ERROR(session_state_.CreateKernels(kernel_registry_manager_));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_, outer_scope_node_args));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Build the OrtValue name->idx mapping
|
||||
common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const logging::Logger& logger) {
|
||||
LOGS(logger, INFO) << "SaveMLValueNameIndexMapping";
|
||||
int idx = 0;
|
||||
|
||||
// we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry
|
||||
for (const auto* input_def : graph_viewer.GetInputsIncludingInitializers()) {
|
||||
idx = ort_value_name_idx_map.Add(input_def->Name());
|
||||
VLOGS(logger, 1) << "Added graph_viewer input with name: " << input_def->Name()
|
||||
<< " to OrtValueIndex with index: " << idx;
|
||||
}
|
||||
|
||||
for (auto& node : graph_viewer.Nodes()) {
|
||||
// build the OrtValue->index map
|
||||
for (const auto* input_def : node.InputDefs()) {
|
||||
if (input_def->Exists()) {
|
||||
idx = ort_value_name_idx_map.Add(input_def->Name());
|
||||
VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name()
|
||||
<< " to OrtValueIndex with index: " << idx;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto* input_def : node.ImplicitInputDefs()) {
|
||||
if (input_def->Exists()) {
|
||||
idx = ort_value_name_idx_map.Add(input_def->Name());
|
||||
VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name()
|
||||
<< " to OrtValueIndex with index: " << idx;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto* output_def : node.OutputDefs()) {
|
||||
if (output_def->Exists()) {
|
||||
ort_value_name_idx_map.Add(output_def->Name());
|
||||
VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name()
|
||||
<< " to OrtValueIndex with index: " << idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// allocate OrtValue for graph outputs when coming from initializers
|
||||
for (const auto& output : graph_viewer.GetOutputs()) {
|
||||
if (output->Exists()) {
|
||||
idx = ort_value_name_idx_map.Add(output->Name());
|
||||
VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx;
|
||||
}
|
||||
}
|
||||
|
||||
LOGS(logger, INFO) << "Done saving OrtValue mappings.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
|
||||
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m,
|
||||
|
|
@ -292,46 +225,6 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PA
|
|||
return common::Status::OK();
|
||||
}
|
||||
|
||||
// Instantiates the kernel for `node` into `op_kernel`.
//
// The node's execution provider type is looked up in `execution_providers`; if the type
// is empty or no provider is registered under it, a FAIL status is returned. Otherwise
// creation is delegated to `custom_registry_manager`, and a failure from it is returned
// with the node name prepended to the error message.
static common::Status CreateOpKernel(const onnxruntime::Node& node, const ExecutionProviders& execution_providers,
                                     const SessionState& session_state,
                                     const KernelRegistryManager& custom_registry_manager,
                                     std::unique_ptr<OpKernel>& op_kernel) {
  const onnxruntime::ProviderType provider_type = node.GetExecutionProviderType();

  // An empty provider type is never looked up, matching the short-circuit of the lookup.
  const IExecutionProvider* provider = nullptr;
  if (!provider_type.empty()) {
    provider = execution_providers.Get(provider_type);
  }

  if (provider == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(),
                           " as there's no execution provider allocated.");
  }

  common::Status creation_status = custom_registry_manager.CreateKernel(node, *provider, session_state, op_kernel);
  if (creation_status.IsOK()) {
    return creation_status;
  }

  // Keep category and code, but say which node failed.
  return common::Status(
      creation_status.Category(), creation_status.Code(),
      MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", creation_status.ErrorMessage()));
}
|
||||
|
||||
// Creates a kernel for every node of the session state's graph viewer and registers it
// with the session state under the node's index. Stops at, and returns, the first
// kernel-creation error; returns Status::OK() once all nodes are processed.
common::Status SaveKernels(const ExecutionProviders& execution_providers,
                           SessionState& session_state,
                           const KernelRegistryManager& custom_registry_manager,
                           const logging::Logger& logger) {
  LOGS(logger, INFO) << "Saving kernels.";

  const auto* viewer = session_state.GetGraphViewer();
  for (auto& graph_node : viewer->Nodes()) {
    // build the kernel, then hand ownership to the session state
    std::unique_ptr<OpKernel> kernel;
    ORT_RETURN_IF_ERROR(
        CreateOpKernel(graph_node, execution_providers, session_state, custom_registry_manager, kernel));
    session_state.AddKernel(graph_node.Index(), std::move(kernel));
  }

  LOGS(logger, INFO) << "Done saving kernels.";

  return Status::OK();
}
|
||||
|
||||
template <typename T> // T is container of const NodeArg* or NodeArg*
|
||||
static bool IsArgNameInInputsOutputs(const std::string& name,
|
||||
const T& graph_args) {
|
||||
|
|
|
|||
|
|
@ -36,14 +36,11 @@ class SessionStateInitializer {
|
|||
KernelRegistryManager& kernel_registry_manager);
|
||||
|
||||
// First perform any transformations and create the execution plan
|
||||
common::Status CreatePlan(const Node* parent_node,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
|
||||
// Then initialize tensors, and save. save kernels and input/output node mappings
|
||||
common::Status CreatePlan(_In_opt_ const Node* parent_node,
|
||||
_In_opt_ const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
|
||||
bool enable_sequential_execution);
|
||||
|
||||
// initialize tensors, and save. save kernels and input/output node mappings
|
||||
// \param implicit_inputs could be NULL
|
||||
common::Status InitializeAndSave(const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs);
|
||||
|
||||
private:
|
||||
const std::basic_string<PATH_CHAR_TYPE>& graph_loc_;
|
||||
onnxruntime::Graph& graph_;
|
||||
|
|
|
|||
|
|
@ -435,7 +435,6 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
|
|||
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs,
|
||||
session_options_.enable_sequential_execution));
|
||||
|
||||
ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&implicit_inputs));
|
||||
|
||||
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
|
||||
// &*subgraph_info.session_state);
|
||||
|
|
@ -533,13 +532,9 @@ common::Status InferenceSession::Initialize() {
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution));
|
||||
ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr));
|
||||
|
||||
// handle any subgraphs
|
||||
ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_));
|
||||
|
||||
session_state_.CalculateNodeIndexInfo();
|
||||
|
||||
is_inited_ = true;
|
||||
|
||||
LOGS(*session_logger_, INFO) << "Session successfully initialized.";
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ struct UnaryNode {
|
|||
std::vector<onnxruntime::NodeArg*> output_args;
|
||||
onnxruntime::Node* p_node;
|
||||
|
||||
UnaryNode(onnxruntime::Graph& graph, const std::string& op,
|
||||
onnxruntime::NodeArg* p_input_arg, onnxruntime::NodeArg* p_output_arg)
|
||||
UnaryNode(onnxruntime::Graph& graph, const std::string& op, onnxruntime::NodeArg* p_input_arg,
|
||||
onnxruntime::NodeArg* p_output_arg)
|
||||
: input_args({p_input_arg}), output_args({p_output_arg}) {
|
||||
int num = NodeCounter::Next();
|
||||
p_node = &graph.AddNode("node" + std::to_string(num), op, "test op", input_args, output_args);
|
||||
|
|
@ -161,9 +161,11 @@ class PlannerTest : public ::testing::Test {
|
|||
std::unique_ptr<SequentialExecutionPlan> plan_;
|
||||
|
||||
public:
|
||||
PlannerTest() : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) {
|
||||
std_kernel_ = KernelDefBuilder().SetName("Transpose").Build();
|
||||
in_place_kernel_ = KernelDefBuilder().SetName("Relu").MayInplace(0, 0).Build();
|
||||
PlannerTest()
|
||||
: model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) {
|
||||
std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
|
||||
in_place_kernel_ =
|
||||
KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build();
|
||||
CPUExecutionProviderInfo epi;
|
||||
auto execution_provider = std::make_unique<CPUExecutionProvider>(epi);
|
||||
execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider));
|
||||
|
|
@ -194,18 +196,20 @@ class PlannerTest : public ::testing::Test {
|
|||
return AddNode(*in_place_kernel_, input, output);
|
||||
}
|
||||
|
||||
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) {
|
||||
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) {
|
||||
auto info = std::make_unique<OpKernelInfo>(*p_node, kernel_def, *execution_providers_.Get(*p_node),
|
||||
state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(),
|
||||
state_.GetFuncMgr(), state_.GetDataTransferMgr());
|
||||
auto dummy = std::make_unique<DummyOpKernel>(*info);
|
||||
op_kernel_infos_.push_back(std::move(info));
|
||||
state_.AddKernel(p_node->Index(), std::move(dummy));
|
||||
if (reg->TryFindKernel(*p_node, onnxruntime::kCpuExecutionProvider) == nullptr) {
|
||||
auto st = reg->Register(
|
||||
KernelCreateInfo(std::make_unique<KernelDef>(kernel_def),
|
||||
[](const OpKernelInfo& info) -> OpKernel* { return new DummyOpKernel(info); }));
|
||||
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
|
||||
}
|
||||
}
|
||||
|
||||
void SetShape(std::string& name, TensorShapeProto* shape) {
|
||||
shape_map_[Arg(name)] = shape;
|
||||
}
|
||||
void SetShape(std::string& name, TensorShapeProto* shape) { shape_map_[Arg(name)] = shape; }
|
||||
|
||||
void SetShape(std::initializer_list<std::pair<std::string&, TensorShapeProto*>> shapes) {
|
||||
for (auto& pair : shapes) {
|
||||
|
|
@ -215,29 +219,27 @@ class PlannerTest : public ::testing::Test {
|
|||
|
||||
void CreatePlan(const std::vector<const NodeArg*>& outer_scope_node_args = {}) {
|
||||
EXPECT_EQ(graph_.Resolve(), Status::OK());
|
||||
state_.SetGraphViewer(std::make_unique<GraphViewer>(graph_));
|
||||
|
||||
OrtValueNameIdxMap& mlvalue_name_idx_map{state_.GetOrtValueNameIdxMap()};
|
||||
state_.SetGraph(graph_);
|
||||
|
||||
int count = 0;
|
||||
for (auto& pair : name_to_arg_) {
|
||||
EXPECT_EQ(mlvalue_name_idx_map.Add(pair.first), count++);
|
||||
}
|
||||
std::shared_ptr<KernelRegistry> reg = std::make_shared<KernelRegistry>();
|
||||
|
||||
for (auto& binding : kernel_bindings_) {
|
||||
BindKernel(binding.first, binding.second);
|
||||
BindKernel(binding.first, binding.second, reg.get());
|
||||
}
|
||||
|
||||
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
kernel_registry_manager.RegisterKernelRegistry(reg);
|
||||
ExecutionProviders execution_providers;
|
||||
execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::move(cpu_execution_provider));
|
||||
auto status = kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
status = state_.CreateKernels(kernel_registry_manager);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
SequentialPlannerTestContext test_context(&shape_map_);
|
||||
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers,
|
||||
kernel_registry_manager, mlvalue_name_idx_map, test_context, plan_);
|
||||
kernel_registry_manager, state_.GetOrtValueNameIdxMap(), test_context, plan_);
|
||||
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size());
|
||||
|
|
|
|||
|
|
@ -48,9 +48,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
|||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
onnxruntime::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float);
|
||||
|
||||
graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def});
|
||||
onnxruntime::Node* node = graph.GetNode(graph.NumberOfNodes() - 1);
|
||||
|
||||
onnxruntime::Node* node = &graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def});
|
||||
node->SetExecutionProviderType(kCpuExecutionProvider);
|
||||
Status status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
|
|
@ -63,11 +62,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
|||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
SessionState state{execution_providers, true, &tp_};
|
||||
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));
|
||||
|
||||
OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
|
||||
mlvalue_name_idx_map.Add("X");
|
||||
mlvalue_name_idx_map.Add("Y");
|
||||
status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
node->SetExecutionProviderType(xp_typ);
|
||||
|
||||
|
|
@ -75,12 +71,10 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
|||
// TODO below line is for testing only. In production use SequentialPlanner::CreatePlan()
|
||||
SequentialPlannerContext context(false);
|
||||
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager,
|
||||
mlvalue_name_idx_map, context, p_seq_exec_plan);
|
||||
state.GetOrtValueNameIdxMap(), context, p_seq_exec_plan);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
state.SetExecutionPlan(std::move(p_seq_exec_plan));
|
||||
|
||||
state.CalculateNodeIndexInfo();
|
||||
|
||||
vector<OrtValue> outputs;
|
||||
ExecutionFrame frame({}, {}, {}, outputs, {}, state);
|
||||
|
||||
|
|
@ -117,21 +111,22 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
|
|||
}
|
||||
|
||||
TEST_F(ExecutionFrameTest, FeedInDataTest) {
|
||||
onnxruntime::Model model("test");
|
||||
onnxruntime::Model model("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
std::unordered_map<std::string, int>{{"", 10}});
|
||||
onnxruntime::Graph& graph = model.MainGraph();
|
||||
TypeProto tensor_float;
|
||||
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
onnxruntime::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float);
|
||||
|
||||
graph.AddNode("node1", "Clip", "Clip operator", ArgMap{&input_def}, ArgMap{&output_def});
|
||||
graph.AddNode("node1", "Clip", "Clip operator", ArgMap{&input_def}, ArgMap{&output_def})
|
||||
.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
graph.Resolve();
|
||||
auto cpu_allocator = TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault);
|
||||
auto element_type = DataTypeImpl::GetType<float>();
|
||||
TensorShape shape({3, 2});
|
||||
std::vector<float> fdata(static_cast<size_t>(shape.Size()));
|
||||
//create fake ml value with owned buffer.
|
||||
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(element_type,
|
||||
shape,
|
||||
cpu_allocator);
|
||||
OrtAllocatorInfo cpuinfo(kCpuExecutionProvider, OrtDeviceAllocator);
|
||||
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(element_type, shape, fdata.data(), cpuinfo);
|
||||
OrtValue value;
|
||||
value.Init(p_tensor.release(),
|
||||
DataTypeImpl::GetType<Tensor>(),
|
||||
|
|
@ -144,15 +139,14 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
|
|||
ExecutionProviders execution_providers;
|
||||
execution_providers.Add(xp_typ, std::move(cpu_xp));
|
||||
EXPECT_TRUE(kernel_registry_manager.RegisterKernels(execution_providers).IsOK());
|
||||
|
||||
SessionState state{execution_providers, true, &tp_};
|
||||
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));
|
||||
auto status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
|
||||
auto x_idx = mlvalue_name_idx_map.Add("X");
|
||||
auto y_idx = mlvalue_name_idx_map.Add("Y");
|
||||
|
||||
state.CalculateNodeIndexInfo();
|
||||
const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap();
|
||||
int x_idx, y_idx;
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK());
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK());
|
||||
|
||||
vector<OrtValue> outputs;
|
||||
ExecutionFrame frame({x_idx}, {value}, {y_idx}, outputs, {}, state);
|
||||
|
|
@ -198,16 +192,20 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
|
|||
kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
//1. prepare input
|
||||
SessionState state{execution_providers, true, &tp_};
|
||||
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));
|
||||
status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
|
||||
const OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
|
||||
|
||||
auto x1_idx = mlvalue_name_idx_map.Add("X1");
|
||||
auto x2_idx = mlvalue_name_idx_map.Add("X2");
|
||||
auto x3_idx = mlvalue_name_idx_map.Add("X3");
|
||||
mlvalue_name_idx_map.Add("T1");
|
||||
mlvalue_name_idx_map.Add("T2");
|
||||
auto t3_idx = mlvalue_name_idx_map.Add("T3");
|
||||
int x1_idx, x2_idx, x3_idx;
|
||||
int t1_idx, t2_idx, t3_idx;
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X1", x1_idx).IsOK());
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X2", x2_idx).IsOK());
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X3", x3_idx).IsOK());
|
||||
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T1", t1_idx).IsOK());
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T2", t2_idx).IsOK());
|
||||
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T3", t3_idx).IsOK());
|
||||
|
||||
auto cpu_allocator = execution_providers.Get(xp_type)->GetAllocator(0, OrtMemTypeDefault);
|
||||
|
||||
|
|
@ -230,8 +228,6 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
|
|||
|
||||
state.SetExecutionPlan(std::move(p_seq_exec_plan));
|
||||
|
||||
state.CalculateNodeIndexInfo();
|
||||
|
||||
vector<OrtValue> outputs;
|
||||
ExecutionFrame frame({x1_idx, x2_idx, x3_idx}, {v1, v2, v3}, {t3_idx}, outputs, {}, state);
|
||||
|
||||
|
|
|
|||
|
|
@ -53,19 +53,28 @@ TEST(SessionStateTest, AddGetKernelTest) {
|
|||
outputs.push_back(&output_arg);
|
||||
onnxruntime::Node& node = graph.AddNode("node_1", "Variable", "node 1.", inputs, outputs);
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK());
|
||||
KernelDef kernel_def;
|
||||
CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}};
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
|
||||
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
|
||||
|
||||
OpKernelInfo p_info(node, kernel_def, execution_provider, s.GetConstantInitializedTensors(),
|
||||
OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider.get(), s.GetConstantInitializedTensors(),
|
||||
s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr());
|
||||
unique_ptr<TestOpKernel> p_kernel;
|
||||
p_kernel.reset(new TestOpKernel(p_info));
|
||||
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
|
||||
std::cout << "node_idx: " << node.Index() << std::endl;
|
||||
|
||||
s.SetGraphViewer(std::make_unique<GraphViewer>(graph));
|
||||
s.AddKernel(node.Index(), std::move(p_kernel));
|
||||
execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider));
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
status = kernel_registry_manager.RegisterKernels(execution_providers);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>();
|
||||
kernel_registry->Register(KernelCreateInfo(
|
||||
std::move(kernel_def), [](const OpKernelInfo& info) -> OpKernel* { return new TestOpKernel(info); }));
|
||||
kernel_registry_manager.RegisterKernelRegistry(kernel_registry);
|
||||
status = s.SetGraphAndCreateKernels(graph, kernel_registry_manager);
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
auto test_kernel = s.GetKernel(node.Index());
|
||||
std::cout << "orig: " << orig_num_outputs << " new: " << test_kernel->Node().OutputDefs().size() << std::endl;
|
||||
EXPECT_EQ(orig_num_outputs, test_kernel->Node().OutputDefs().size());
|
||||
|
|
@ -79,8 +88,7 @@ class TestParam {
|
|||
};
|
||||
TestParam param_list[] = {{3, true}, {4, true}, {3, false}, {4, false}};
|
||||
} // namespace
|
||||
class SessionStateTestP : public testing::TestWithParam<TestParam> {
|
||||
};
|
||||
class SessionStateTestP : public testing::TestWithParam<TestParam> {};
|
||||
// Test that we separate out constant and non-constant initializers correctly
|
||||
TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
||||
const TestParam& param = GetParam();
|
||||
|
|
@ -104,8 +112,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
SessionState session_state(execution_providers, param.enable_mem_pattern, &tp);
|
||||
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph,
|
||||
session_state, execution_providers, krm);
|
||||
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, session_state,
|
||||
execution_providers, krm);
|
||||
|
||||
GraphPartitioner partitioner(krm, execution_providers);
|
||||
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
|
||||
|
|
@ -114,9 +122,6 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
status = session_initializer.CreatePlan(nullptr, nullptr, true);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
status = session_initializer.InitializeAndSave(nullptr);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
const auto& initialized_tensors = session_state.GetInitializedTensors();
|
||||
const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors();
|
||||
|
||||
|
|
@ -144,7 +149,6 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
|
|||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(SessionStateTests, SessionStateTestP,
|
||||
testing::ValuesIn(param_list));
|
||||
INSTANTIATE_TEST_CASE_P(SessionStateTests, SessionStateTestP, testing::ValuesIn(param_list));
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ TEST(CApiTest, load_float_tensor_with_external_data) {
|
|||
}
|
||||
|
||||
#if defined(__amd64__) || defined(_M_X64)
|
||||
|
||||
#ifdef NDEBUG
|
||||
TEST(CApiTest, load_huge_tensor_with_external_data) {
|
||||
FILE* fp;
|
||||
std::basic_string<ORTCHAR_T> filename(ORT_TSTR("tensor_XXXXXX"));
|
||||
|
|
@ -183,5 +183,6 @@ TEST(CApiTest, load_huge_tensor_with_external_data) {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,225 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
#include <core/graph/model.h>
|
||||
#include <core/framework/execution_providers.h>
|
||||
#include <core/framework/kernel_registry_manager.h>
|
||||
#include <core/framework/session_state.h>
|
||||
#include <core/framework/graph_partitioner.h>
|
||||
#include <core/providers/cpu/cpu_execution_provider.h>
|
||||
#ifdef USE_CUDA
|
||||
#include <core/providers/cuda/cuda_execution_provider.h>
|
||||
#endif
|
||||
#ifdef USE_MKLDNN
|
||||
#include <core/providers/mkldnn/mkldnn_execution_provider.h>
|
||||
#endif
|
||||
#include <core/platform/env.h>
|
||||
#include <core/graph/onnx_protobuf.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
using namespace google::protobuf::io;
|
||||
|
||||
constexpr const char* model_str =
|
||||
"ir_version: 4\n"
|
||||
"graph {\n"
|
||||
" node {\n"
|
||||
" input: \"X\"\n"
|
||||
" input: \"X\"\n"
|
||||
" output: \"Y\"\n"
|
||||
" op_type: \"MatMul\"\n"
|
||||
" }\n"
|
||||
" name: \"test-model\"\n"
|
||||
" input {\n"
|
||||
" name: \"X\"\n"
|
||||
" type {\n"
|
||||
" tensor_type {\n"
|
||||
" elem_type: 1\n"
|
||||
" shape {\n"
|
||||
" dim {\n"
|
||||
" dim_value: 2\n"
|
||||
" }\n"
|
||||
" dim {\n"
|
||||
" dim_value: 2\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" output {\n"
|
||||
" name: \"Y\"\n"
|
||||
" type {\n"
|
||||
" tensor_type {\n"
|
||||
" elem_type: 1\n"
|
||||
" shape {\n"
|
||||
" dim {\n"
|
||||
" dim_value: 2\n"
|
||||
" }\n"
|
||||
" dim {\n"
|
||||
" dim_value: 2\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"opset_import {\n"
|
||||
" version: 8\n"
|
||||
"}";
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
#define BM_BREAK_IF_ERROR(expr) \
|
||||
do { \
|
||||
auto _status = (expr); \
|
||||
if ((!_status.IsOK())) state.SkipWithError(_status.ErrorMessage().c_str()); \
|
||||
} while (0)
|
||||
|
||||
Status CreateModelFromStr(const char* str, std::unique_ptr<Model>* out) {
|
||||
ONNX_NAMESPACE::ModelProto mp;
|
||||
if (!google::protobuf::TextFormat::ParseFromString(str, &mp)) throw std::runtime_error("load model failed");
|
||||
*out = std::make_unique<Model>(mp);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateExecutionProviders(std::unique_ptr<ExecutionProviders>* ret) {
|
||||
std::unique_ptr<ExecutionProviders> execution_providers = std::make_unique<ExecutionProviders>();
|
||||
#ifdef USE_CUDA
|
||||
{
|
||||
CUDAExecutionProviderInfo epi;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
execution_providers->Add(onnxruntime::kCudaExecutionProvider, std::make_unique<CUDAExecutionProvider>(epi)));
|
||||
}
|
||||
#endif
|
||||
#ifdef USE_MKLDNN
|
||||
{
|
||||
MKLDNNExecutionProviderInfo epi;
|
||||
ORT_RETURN_IF_ERROR(execution_providers->Add(onnxruntime::kMklDnnExecutionProvider,
|
||||
std::make_unique<MKLDNNExecutionProvider>(epi)));
|
||||
}
|
||||
#endif
|
||||
{
|
||||
CPUExecutionProviderInfo epi;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
execution_providers->Add(onnxruntime::kCpuExecutionProvider, std::make_unique<CPUExecutionProvider>(epi)));
|
||||
}
|
||||
*ret = std::move(execution_providers);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateKernelRegistryManagerFromModel(std::unique_ptr<KernelRegistryManager>* ret, Model* model, concurrency::ThreadPool& tp) {
|
||||
std::unique_ptr<ExecutionProviders> execution_providers;
|
||||
ORT_RETURN_IF_ERROR(CreateExecutionProviders(&execution_providers));
|
||||
std::unique_ptr<KernelRegistryManager> kernel_registry_manager = std::make_unique<KernelRegistryManager>();
|
||||
ORT_RETURN_IF_ERROR(kernel_registry_manager->RegisterKernels(*execution_providers));
|
||||
SessionState s{*execution_providers, true, &tp};
|
||||
s.SetLogger(logging::LoggingManager::DefaultLogger());
|
||||
|
||||
ORT_RETURN_IF_ERROR(model->MainGraph().Resolve());
|
||||
s.SetGraphViewer(std::make_unique<GraphViewer>(model->MainGraph()));
|
||||
GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers);
|
||||
ORT_RETURN_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr()));
|
||||
*ret = std::move(kernel_registry_manager);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void SearchKernelRegistry_IMPL(benchmark::State& state, Model* model) {
|
||||
std::unique_ptr<KernelRegistryManager> kernel_registry_manager;
|
||||
concurrency::ThreadPool tp{"test", 1};
|
||||
auto st = CreateKernelRegistryManagerFromModel(&kernel_registry_manager, model, tp);
|
||||
if (!st.IsOK()) throw std::runtime_error("failed");
|
||||
for (auto _ : state) {
|
||||
for (const auto& n : model->MainGraph().Nodes()) {
|
||||
const KernelCreateInfo* info;
|
||||
BM_BREAK_IF_ERROR(kernel_registry_manager->SearchKernelRegistry(n, &info));
|
||||
if (info == nullptr) state.SkipWithError("Search kernel failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void BM_SearchKernelRegistry_SingleNodeModel(benchmark::State& state) {
|
||||
std::unique_ptr<Model> model;
|
||||
Status st = CreateModelFromStr(model_str, &model);
|
||||
if (!st.IsOK()) throw std::runtime_error("failed");
|
||||
SearchKernelRegistry_IMPL(state, model.get());
|
||||
}
|
||||
|
||||
BENCHMARK(BM_SearchKernelRegistry_SingleNodeModel);
|
||||
|
||||
static void BM_SearchKernelRegistry_RealModel_tiny_yolo(benchmark::State& state) {
|
||||
std::shared_ptr<onnxruntime::Model> model;
|
||||
auto st = onnxruntime::Model::Load("../models/opset8/test_tiny_yolov2/model.onnx", model);
|
||||
SearchKernelRegistry_IMPL(state, model.get());
|
||||
}
|
||||
|
||||
BENCHMARK(BM_SearchKernelRegistry_RealModel_tiny_yolo);
|
||||
|
||||
static void BM_SearchKernelRegistry_RealModel_inception_v4(benchmark::State& state) {
|
||||
std::shared_ptr<onnxruntime::Model> model;
|
||||
auto st = onnxruntime::Model::Load("../models/opset9/tf_inception_v4/model.onnx", model);
|
||||
SearchKernelRegistry_IMPL(state, model.get());
|
||||
}
|
||||
|
||||
BENCHMARK(BM_SearchKernelRegistry_RealModel_inception_v4);
|
||||
|
||||
static void BM_PartitionModel_tiny_yolo(benchmark::State& state) {
|
||||
int fd;
|
||||
Status status = Env::Default().FileOpenRd("../models/opset8/test_tiny_yolov2/model.onnx", fd);
|
||||
if (!status.IsOK()) throw std::runtime_error("open test data failed");
|
||||
auto raw_input = std::unique_ptr<ZeroCopyInputStream>(std::make_unique<FileInputStream>(fd));
|
||||
auto coded_input = std::make_unique<CodedInputStream>(raw_input.get());
|
||||
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
if (!model_proto.ParseFromCodedStream(coded_input.get())) throw std::runtime_error("open test data failed");
|
||||
std::unique_ptr<ExecutionProviders> execution_providers;
|
||||
BM_BREAK_IF_ERROR(CreateExecutionProviders(&execution_providers));
|
||||
std::unique_ptr<KernelRegistryManager> kernel_registry_manager = std::make_unique<KernelRegistryManager>();
|
||||
status = kernel_registry_manager->RegisterKernels(*execution_providers);
|
||||
if (!status.IsOK()) throw std::runtime_error("RegisterKernels failed");
|
||||
concurrency::ThreadPool tp{"test", 1};
|
||||
|
||||
for (auto _ : state) {
|
||||
state.PauseTiming();
|
||||
std::shared_ptr<onnxruntime::Model> model = std::make_shared<onnxruntime::Model>(model_proto);
|
||||
SessionState s{*execution_providers, true, &tp};
|
||||
s.SetLogger(logging::LoggingManager::DefaultLogger());
|
||||
BM_BREAK_IF_ERROR(model->MainGraph().Resolve());
|
||||
s.SetGraphViewer(std::make_unique<GraphViewer>(model->MainGraph()));
|
||||
GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers);
|
||||
state.ResumeTiming();
|
||||
BM_BREAK_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr()));
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_PartitionModel_tiny_yolo);
|
||||
|
||||
static void BM_PartitionModel_inception_v4(benchmark::State& state) {
|
||||
int fd;
|
||||
Status status = Env::Default().FileOpenRd("../models/opset9/tf_inception_v4/model.onnx", fd);
|
||||
if (!status.IsOK()) throw std::runtime_error("open test data failed");
|
||||
auto raw_input = std::unique_ptr<ZeroCopyInputStream>(std::make_unique<FileInputStream>(fd));
|
||||
auto coded_input = std::make_unique<CodedInputStream>(raw_input.get());
|
||||
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
if (!model_proto.ParseFromCodedStream(coded_input.get())) throw std::runtime_error("open test data failed");
|
||||
std::unique_ptr<ExecutionProviders> execution_providers;
|
||||
BM_BREAK_IF_ERROR(CreateExecutionProviders(&execution_providers));
|
||||
std::unique_ptr<KernelRegistryManager> kernel_registry_manager = std::make_unique<KernelRegistryManager>();
|
||||
status = kernel_registry_manager->RegisterKernels(*execution_providers);
|
||||
if (!status.IsOK()) throw std::runtime_error("RegisterKernels failed");
|
||||
concurrency::ThreadPool tp{"test", 1};
|
||||
|
||||
for (auto _ : state) {
|
||||
state.PauseTiming();
|
||||
std::shared_ptr<onnxruntime::Model> model = std::make_shared<onnxruntime::Model>(model_proto);
|
||||
SessionState s{*execution_providers, true, &tp};
|
||||
s.SetLogger(logging::LoggingManager::DefaultLogger());
|
||||
BM_BREAK_IF_ERROR(model->MainGraph().Resolve());
|
||||
s.SetGraphViewer(std::make_unique<GraphViewer>(model->MainGraph()));
|
||||
GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers);
|
||||
state.ResumeTiming();
|
||||
BM_BREAK_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr()));
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_PartitionModel_inception_v4);
|
||||
|
|
@ -42,14 +42,12 @@ TEST(MemcpyTest, copy1) {
|
|||
Model model(mp);
|
||||
st = model.MainGraph().Resolve();
|
||||
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
|
||||
s.SetGraphViewer(std::make_unique<GraphViewer>(model.MainGraph()));
|
||||
PutAllNodesOnOneProvider(model.MainGraph(), onnxruntime::kCpuExecutionProvider);
|
||||
SessionStateInitializer session_initializer{true, ORT_TSTR(""), model.MainGraph(),
|
||||
s, execution_providers, kernel_registry_manager};
|
||||
st = session_initializer.CreatePlan(nullptr, {}, true);
|
||||
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
|
||||
st = session_initializer.InitializeAndSave(nullptr);
|
||||
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
|
||||
|
||||
AllocatorPtr allocator =
|
||||
execution_providers.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);
|
||||
auto* data_type = DataTypeImpl::GetType<float>();
|
||||
|
|
|
|||
Loading…
Reference in a new issue