Optimize kernel index (#1672)

This commit is contained in:
Changming Sun 2019-08-22 10:26:35 -07:00 committed by GitHub
parent a818740d91
commit 4de0aa8049
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 214 additions and 483 deletions

View file

@ -526,7 +526,7 @@ install(TARGETS onnx_test_runner
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
if(onnxruntime_BUILD_BENCHMARKS)
add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc ${TEST_SRC_DIR}/onnx/microbenchmark/model_init.cc)
add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc)
target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} benchmark)
onnxruntime_add_include_to_target(onnxruntime_benchmark gsl)
if(WIN32)

View file

@ -11,23 +11,96 @@
#include "core/framework/utils.h"
using namespace ::onnxruntime::common;
namespace onnxruntime {
// Takes ownership of the supplied graph viewer.
// Passing a null viewer is treated as an internal error and trips ORT_ENFORCE.
void SessionState::SetGraphViewer(std::unique_ptr<onnxruntime::GraphViewer> graph_viewer) {
  ORT_ENFORCE(graph_viewer != nullptr);
  graph_viewer_ = std::move(graph_viewer);
}
// Returns a non-owning pointer to the held graph viewer (null when none has been set).
const GraphViewer* SessionState::GetGraphViewer() const {
  return graph_viewer_.get();
}
// Binds this SessionState to `graph` and registers every value name it uses in
// ort_value_name_idx_map_.
//
// A fresh GraphViewer over `graph` replaces any previously held viewer. Names are then added in
// this order: graph inputs (including initializers), each node's explicit inputs, implicit inputs
// and outputs, and finally the graph outputs.
//
// Always returns Status::OK().
Status SessionState::SetGraph(const Graph& graph) {
  graph_viewer_ = std::make_unique<onnxruntime::GraphViewer>(graph);
  auto& logger = Logger();
  // use graph_viewer_ to initialize ort_value_name_idx_map_
  LOGS(logger, INFO) << "SaveMLValueNameIndexMapping";
  int idx = 0;

  // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry
  for (const auto* input_def : graph_viewer_->GetInputsIncludingInitializers()) {
    idx = ort_value_name_idx_map_.Add(input_def->Name());
    VLOGS(logger, 1) << "Added graph_viewer_ input with name: " << input_def->Name()
                     << " to OrtValueIndex with index: " << idx;
  }

  for (auto& node : graph_viewer_->Nodes()) {
    // build the OrtValue->index map
    for (const auto* input_def : node.InputDefs()) {
      if (input_def->Exists()) {
        idx = ort_value_name_idx_map_.Add(input_def->Name());
        VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }

    for (const auto* input_def : node.ImplicitInputDefs()) {
      if (input_def->Exists()) {
        idx = ort_value_name_idx_map_.Add(input_def->Name());
        VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }

    for (const auto* output_def : node.OutputDefs()) {
      if (output_def->Exists()) {
        // Keep the index returned by Add(); without the assignment the log line below would
        // report the index of whichever name was added before this one.
        idx = ort_value_name_idx_map_.Add(output_def->Name());
        VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }
  }

  // allocate OrtValue for graph outputs when coming from initializers
  for (const auto& output : graph_viewer_->GetOutputs()) {
    if (output->Exists()) {
      idx = ort_value_name_idx_map_.Add(output->Name());
      VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx;
    }
  }

  LOGS(logger, INFO) << "Done saving OrtValue mappings.";
  return Status::OK();
}
void SessionState::AddKernel(onnxruntime::NodeIndex node_id, std::unique_ptr<OpKernel> p_kernel) {
// assumes vector is already resize()'ed to the number of nodes in the graph
session_kernels_[node_id] = std::move(p_kernel);
// Creates one OpKernel per node of the current graph and caches it in session_kernels_, indexed
// by the node's index, then builds node_index_info_.
//
// \param custom_registry_manager registry manager used to instantiate each node's kernel.
// \return Status::OK() on success; FAIL when no graph has been set or a node has no usable
//         execution provider; otherwise the category/code of the failed kernel creation with the
//         node name added to the message.
//
// session_kernels_ stores raw owning pointers, so kernels left over from an earlier call are
// deleted here before the vector is rebuilt.
Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) {
  if (graph_viewer_ == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SetGraph must be called before CreateKernels.");
  }

  // clear() alone would drop the pointers without freeing the kernels they own.
  for (OpKernel* stale_kernel : session_kernels_) {
    delete stale_kernel;
  }
  session_kernels_.clear();

  const GraphNodes& nodes = graph_viewer_->Nodes();
  if (!nodes.empty()) {
    // Node indices are not assumed to be dense, so size the cache by the largest index in use.
    size_t max_nodeid = 0;
    for (auto& node : nodes) {
      max_nodeid = std::max(max_nodeid, node.Index());
    }
    session_kernels_.resize(max_nodeid + 1, nullptr);

    for (auto& node : nodes) {
      // construct and save the kernels
      std::unique_ptr<OpKernel> op_kernel;
      onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType();

      const IExecutionProvider* exec_provider = nullptr;
      if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(),
                               " as there's no execution provider allocated.");
      }

      common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel);
      if (!status.IsOK()) {
        return common::Status(
            status.Category(), status.Code(),
            MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage()));
      }

      assert(session_kernels_[node.Index()] == nullptr);
      // Ownership moves to session_kernels_; the slot was sized above to cover every node index.
      session_kernels_[node.Index()] = op_kernel.release();
    }
  }

  node_index_info_ = std::make_unique<NodeIndexInfo>(*graph_viewer_, ort_value_name_idx_map_);
  return Status::OK();
}
void SessionState::SetExecutionPlan(std::unique_ptr<SequentialExecutionPlan> p_seq_exec_plan) {
@ -38,7 +111,6 @@ const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p
Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d,
bool constant) {
ORT_ENFORCE(ort_value_index >= 0 && ort_value_index <= ort_value_name_idx_map_.MaxIdx());
auto p = initialized_tensors_.insert({ort_value_index, ort_value});
if (!p.second)
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated ort_value index:", ort_value_index,
@ -55,9 +127,7 @@ Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& o
return Status::OK();
}
// Returns the map of ort_value_index to initialized tensor held by this SessionState.
const std::unordered_map<int, OrtValue>& SessionState::GetInitializedTensors() const { return initialized_tensors_; }
const std::unordered_map<int, OrtValue>& SessionState::GetConstantInitializedTensors() const {
return constant_initialized_tensors_;
@ -86,7 +156,8 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
return key;
}
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const {
const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(
const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const {
int64_t key = CalculateMemoryPatternsKey(input_shapes);
std::lock_guard<OrtMutex> lock(mem_patterns_lock_);
@ -96,8 +167,9 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<
return it->second.get();
}
Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes,
std::unique_ptr<MemoryPatternGroup> mem_patterns) const {
Status SessionState::UpdateMemoryPatternGroupCache(
const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes,
std::unique_ptr<MemoryPatternGroup> mem_patterns) const {
int64_t key = CalculateMemoryPatternsKey(input_shapes);
std::lock_guard<OrtMutex> lock(mem_patterns_lock_);
@ -109,9 +181,7 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<std::refere
return Status::OK();
}
// Returns the enable_mem_pattern flag stored in this SessionState.
bool SessionState::GetEnableMemoryPattern() const { return enable_mem_pattern_; }
common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info) {
// in the future we could support multiple nodes on difference devices using an input, however right now
@ -144,10 +214,11 @@ common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& in
if (current_provider == new_provider) {
entries.push_back(node_info);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Using an input in multiple nodes on different devices is not supported currently. Input:",
input_name, " is used by node ", existing_entry.p_node->Name(), " (", current_provider,
") and node ", node_info.p_node->Name(), " (", new_provider, ").");
return ORT_MAKE_STATUS(
ONNXRUNTIME, NOT_IMPLEMENTED,
"Using an input in multiple nodes on different devices is not supported currently. Input:", input_name,
" is used by node ", existing_entry.p_node->Name(), " (", current_provider, ") and node ",
node_info.p_node->Name(), " (", new_provider, ").");
}
}
}
@ -178,16 +249,15 @@ const SessionState::NameNodeInfoMapType& SessionState::GetOutputNodeInfoMap() co
return output_names_to_nodeinfo_mapping_;
}
void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index,
const std::string& attribute_name,
void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name,
std::unique_ptr<SessionState> session_state) {
auto entry = subgraph_session_states_.find(index);
// make sure this is new. internal logic error if it is not so using ORT_ENFORCE.
if (entry != subgraph_session_states_.cend()) {
const auto& existing_entries = entry->second;
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(),
"Entry exists in node ", index, " for attribute ", attribute_name);
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
" for attribute ", attribute_name);
}
subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state)));
@ -215,19 +285,8 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex
return const_cast<SessionState*>(this)->GetMutableSubgraphSessionState(index, attribute_name);
}
// Builds node_index_info_ for this state from the current graph viewer, then does the same for
// every nested subgraph SessionState. Requires a graph viewer to be set.
void SessionState::CalculateNodeIndexInfo() {
  ORT_ENFORCE(graph_viewer_);
  node_index_info_ = std::make_unique<NodeIndexInfo>(*graph_viewer_, ort_value_name_idx_map_);

  // subgraph_session_states_ maps node index -> (attribute name -> subgraph state).
  for (auto& per_node : subgraph_session_states_) {
    auto& subgraphs_by_attribute = per_node.second;
    for (auto& named_subgraph : subgraphs_by_attribute) {
      named_subgraph.second->CalculateNodeIndexInfo();
    }
  }
}
// Returns the cached NodeIndexInfo.
// Trips ORT_ENFORCE when node_index_info_ has not been built yet.
const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
  ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo.");
  return *node_index_info_;
}
} // namespace onnxruntime

View file

@ -40,33 +40,41 @@ struct MemoryPatternGroup;
* SessionState should be modified by the inference session class only.
* It is supposed to be passed by const-ref only to all the executors.
* This class owns all the initializers.
* Brief usage:
* SessionState s(...);
* for(...) s.AddInitializedTensor(...);
* s.SetGraphAndCreateKernels(...);
* Then you can use:
* s.GetKernel(...);
*/
class SessionState {
public:
SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern, concurrency::ThreadPool* thread_pool)
SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern,
concurrency::ThreadPool* thread_pool)
: execution_providers_{execution_providers}, enable_mem_pattern_(enable_mem_pattern), thread_pool_(thread_pool) {}
~SessionState() {
  // Kernels are cached as raw pointers, so each one is deleted explicitly.
  for (auto* kernel : session_kernels_) {
    delete kernel;
  }
  // Run the cleanup callback registered for each initialized tensor.
  for (auto& entry : deleter_for_initialized_tensors_) {
    auto& callback = entry.second;
    callback.f(callback.param);
  }
}
// Graph viewer.
void SetGraphViewer(std::unique_ptr<GraphViewer> graph_viewer);
const GraphViewer* GetGraphViewer() const;
// kernels
// Get kernel for specified node.
// It should be called right before graph execution only.
const OpKernel* GetKernel(NodeIndex node_id) const;
void AddKernel(NodeIndex node_id, std::unique_ptr<OpKernel> p_kernel);
// Looks up the cached kernel for `node_id`.
// Returns nullptr when the index is outside the kernel cache or the slot holds no kernel.
const OpKernel* GetKernel(size_t node_id) const {
  if (node_id >= session_kernels_.size()) {
    return nullptr;
  }
  return session_kernels_[node_id];
}
const ExecutionProviders& GetExecutionProviders() const noexcept { return execution_providers_; }
const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_name_idx_map_; }
OrtValueNameIdxMap& GetOrtValueNameIdxMap() noexcept { return ort_value_name_idx_map_; }
// initialized tensors
/**
@ -77,6 +85,12 @@ class SessionState {
*/
Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant);
Status SetGraph(const Graph& graph);
Status CreateKernels(const KernelRegistryManager& custom_registry_manager);
// Convenience wrapper: runs SetGraph(graph) and, if that succeeds, CreateKernels(...).
// The first failing Status is returned unchanged.
Status SetGraphAndCreateKernels(const Graph& graph, const KernelRegistryManager& custom_registry_manager) {
  const Status graph_status = SetGraph(graph);
  if (!graph_status.IsOK()) {
    return graph_status;
  }
  return CreateKernels(custom_registry_manager);
}
/**
* Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the
* execution frame to setup the appropriate OrtValue vectors.
@ -85,8 +99,8 @@ class SessionState {
const std::unordered_map<int, OrtValue>& GetInitializedTensors() const;
/**
* Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant
* and cannot be overridden at runtime.
* Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant
* and cannot be overridden at runtime.
* The lifetime of returned OrtValues are limited by this SessionState object.
*/
const std::unordered_map<int, OrtValue>& GetConstantInitializedTensors() const;
@ -96,12 +110,12 @@ class SessionState {
const SequentialExecutionPlan* GetExecutionPlan() const;
/**
Set the logger to use for this session.
Set the logger to use for this session.
*/
SessionState& SetLogger(const logging::Logger& logger);
/**
Get the logger for this session.
Get the logger for this session.
Falls back to returning Logging::LoggingManager::DefaultLogger if SetLogger has not been called.
*/
const logging::Logger& Logger() const;
@ -120,10 +134,11 @@ class SessionState {
/**
Get cached memory pattern based on input shapes
*/
const MemoryPatternGroup* GetMemoryPatternGroup(const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const;
const MemoryPatternGroup* GetMemoryPatternGroup(
const std::vector<std::reference_wrapper<const TensorShape>>& input_shapes) const;
/**
Set generated memory pattern with a given input shapes.
Set generated memory pattern with a given input shapes.
Const as it's an internal cache update only.
*/
Status UpdateMemoryPatternGroupCache(const std::vector<std::reference_wrapper<const TensorShape>>& input_shape,
@ -142,10 +157,7 @@ class SessionState {
* \param kci0 Nullable
*/
// Stores the arguments as given; note that `device` keeps the address of `device0`, so the
// referenced OrtDevice must outlive this NodeInfo.
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0)
    : index(index0), p_node(p_node0), kci(kci0), device(&device0) {}
size_t index;
// Nullable
@ -187,7 +199,6 @@ class SessionState {
void SetDataTransferMgr(const DataTransferManager* data_transfer_mgr) { data_transfer_mgr_ = data_transfer_mgr; }
std::vector<BufferUniquePtr>& GetMutableWeightsBuffers() { return weights_buffers_; }
void CalculateNodeIndexInfo();
const NodeIndexInfo& GetNodeIndexInfo() const;
private:
@ -195,7 +206,7 @@ class SessionState {
// cache of the constructed kernels to avoid spending construction
// time per executor
std::unordered_map<NodeIndex, std::unique_ptr<OpKernel>> session_kernels_;
std::vector<OpKernel*> session_kernels_;
std::unique_ptr<GraphViewer> graph_viewer_;
const ExecutionProviders& execution_providers_; // owned by InferenceSession
@ -231,7 +242,7 @@ class SessionState {
std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
SubgraphSessionStateMap subgraph_session_states_;
//It could be NULL
// It could be NULL
concurrency::ThreadPool* const thread_pool_;
bool export_fused_dll_ = false;

View file

@ -27,9 +27,6 @@
namespace onnxruntime {
static common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer,
OrtValueNameIdxMap& ort_value_name_idx_map,
const logging::Logger& logger);
// T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status'
template <typename T>
@ -40,11 +37,6 @@ static common::Status SaveInitializedTensors(const Env& env, const std::basic_st
const logging::Logger& logger,
const DataTransferManager& data_transfer_mgr);
static common::Status SaveKernels(const ExecutionProviders& execution_providers,
SessionState& session_state,
const KernelRegistryManager& custom_registry_manager,
const logging::Logger& logger);
static common::Status SaveInputOutputNamesToNodeMapping(
const onnxruntime::Graph& graph,
const KernelRegistryManager& custom_registry_manager,
@ -68,11 +60,11 @@ common::Status SessionStateInitializer::CreatePlan(
const Node* parent_node,
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
bool enable_sequential_execution) {
auto graph_viewer = std::make_unique<onnxruntime::GraphViewer>(graph_);
session_state_.SetGraph(graph_);
const GraphViewer* graph_viewer = session_state_.GetGraphViewer();
// populate the SessionState OrtValueNameIdxMap
auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap();
ORT_RETURN_IF_ERROR(SaveMLValueNameIndexMapping(*graph_viewer, ort_value_name_idx_map, logger_));
const auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap();
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
std::vector<const NodeArg*> valid_outer_scope_node_args;
@ -92,17 +84,10 @@ common::Status SessionStateInitializer::CreatePlan(
execution_providers_, kernel_registry_manager_,
ort_value_name_idx_map, context, exec_plan));
session_state_.SetExecutionPlan(std::move(exec_plan));
session_state_.SetGraphViewer(std::move(graph_viewer));
return Status::OK();
}
common::Status SessionStateInitializer::InitializeAndSave(
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs) {
const auto* exec_plan_ptr = session_state_.GetExecutionPlan();
ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first.");
const auto& ort_value_name_idx_map{session_state_.GetOrtValueNameIdxMap()};
std::unique_ptr<ITensorAllocator> tensor_allocator_(ITensorAllocator::Create(
enable_mem_pattern_, *exec_plan_ptr, execution_providers_, session_state_.GetMutableWeightsBuffers()));
@ -119,64 +104,12 @@ common::Status SessionStateInitializer::InitializeAndSave(
// TODO: make it better
graph_.CleanAllInitializedTensors();
ORT_RETURN_IF_ERROR(SaveKernels(execution_providers_, session_state_, kernel_registry_manager_, logger_));
ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_,
implicit_inputs));
ORT_RETURN_IF_ERROR(session_state_.CreateKernels(kernel_registry_manager_));
ORT_RETURN_IF_ERROR(
SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_, outer_scope_node_args));
return Status::OK();
}
// Build the OrtValue name->idx mapping.
//
// Names are added to `ort_value_name_idx_map` in this order: graph inputs (including
// initializers), each node's explicit inputs, implicit inputs and outputs, and finally the graph
// outputs. Each addition is reported through `logger` at verbose level 1.
// Always returns Status::OK().
common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, OrtValueNameIdxMap& ort_value_name_idx_map,
                                           const logging::Logger& logger) {
  LOGS(logger, INFO) << "SaveMLValueNameIndexMapping";
  int idx = 0;

  // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry
  for (const auto* input_def : graph_viewer.GetInputsIncludingInitializers()) {
    idx = ort_value_name_idx_map.Add(input_def->Name());
    VLOGS(logger, 1) << "Added graph_viewer input with name: " << input_def->Name()
                     << " to OrtValueIndex with index: " << idx;
  }

  for (auto& node : graph_viewer.Nodes()) {
    // build the OrtValue->index map
    for (const auto* input_def : node.InputDefs()) {
      if (input_def->Exists()) {
        idx = ort_value_name_idx_map.Add(input_def->Name());
        VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }

    for (const auto* input_def : node.ImplicitInputDefs()) {
      if (input_def->Exists()) {
        idx = ort_value_name_idx_map.Add(input_def->Name());
        VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }

    for (const auto* output_def : node.OutputDefs()) {
      if (output_def->Exists()) {
        // Keep the index returned by Add(); without the assignment the log line below would
        // report the index of whichever name was added before this one.
        idx = ort_value_name_idx_map.Add(output_def->Name());
        VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name()
                         << " to OrtValueIndex with index: " << idx;
      }
    }
  }

  // allocate OrtValue for graph outputs when coming from initializers
  for (const auto& output : graph_viewer.GetOutputs()) {
    if (output->Exists()) {
      idx = ort_value_name_idx_map.Add(output->Name());
      VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx;
    }
  }

  LOGS(logger, INFO) << "Done saving OrtValue mappings.";
  return Status::OK();
}
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m,
@ -292,46 +225,6 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PA
return common::Status::OK();
}
// Instantiates the kernel for `node` into `op_kernel`.
//
// Fails with FAIL when the node has no execution provider type or that provider is not present in
// `execution_providers`. A failure from the registry manager is re-wrapped with the node name
// while keeping its category and code.
static common::Status CreateOpKernel(const onnxruntime::Node& node, const ExecutionProviders& execution_providers,
                                     const SessionState& session_state,
                                     const KernelRegistryManager& custom_registry_manager,
                                     std::unique_ptr<OpKernel>& op_kernel) {
  const onnxruntime::ProviderType provider_type = node.GetExecutionProviderType();

  // Only query the provider collection when the node actually names a provider.
  const IExecutionProvider* provider = provider_type.empty() ? nullptr : execution_providers.Get(provider_type);
  if (provider == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(),
                           " as there's no execution provider allocated.");
  }

  common::Status create_status = custom_registry_manager.CreateKernel(node, *provider, session_state, op_kernel);
  if (create_status.IsOK()) {
    return create_status;
  }

  return common::Status(
      create_status.Category(), create_status.Code(),
      MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", create_status.ErrorMessage()));
}
// Creates a kernel for every node visible through the session state's graph viewer and hands each
// one to the session state. Stops at, and returns, the first kernel-creation failure.
common::Status SaveKernels(const ExecutionProviders& execution_providers,
                           SessionState& session_state,
                           const KernelRegistryManager& custom_registry_manager,
                           const logging::Logger& logger) {
  LOGS(logger, INFO) << "Saving kernels.";

  const GraphViewer* viewer = session_state.GetGraphViewer();
  for (auto& node : viewer->Nodes()) {
    // construct and save the kernels
    std::unique_ptr<OpKernel> kernel;
    ORT_RETURN_IF_ERROR(CreateOpKernel(node, execution_providers, session_state, custom_registry_manager, kernel));
    session_state.AddKernel(node.Index(), std::move(kernel));
  }

  LOGS(logger, INFO) << "Done saving kernels.";
  return Status::OK();
}
template <typename T> // T is container of const NodeArg* or NodeArg*
static bool IsArgNameInInputsOutputs(const std::string& name,
const T& graph_args) {

View file

@ -36,14 +36,11 @@ class SessionStateInitializer {
KernelRegistryManager& kernel_registry_manager);
// First perform any transformations and create the execution plan
common::Status CreatePlan(const Node* parent_node,
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
// Then initialize and save tensors, and save kernels and input/output node mappings
common::Status CreatePlan(_In_opt_ const Node* parent_node,
_In_opt_ const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
bool enable_sequential_execution);
// initialize tensors, and save. save kernels and input/output node mappings
// \param implicit_inputs could be NULL
common::Status InitializeAndSave(const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs);
private:
const std::basic_string<PATH_CHAR_TYPE>& graph_loc_;
onnxruntime::Graph& graph_;

View file

@ -435,7 +435,6 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs,
session_options_.enable_sequential_execution));
ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&implicit_inputs));
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
// &*subgraph_info.session_state);
@ -533,13 +532,9 @@ common::Status InferenceSession::Initialize() {
}
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution));
ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr));
// handle any subgraphs
ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_));
session_state_.CalculateNodeIndexInfo();
is_inited_ = true;
LOGS(*session_logger_, INFO) << "Session successfully initialized.";

View file

@ -34,8 +34,8 @@ struct UnaryNode {
std::vector<onnxruntime::NodeArg*> output_args;
onnxruntime::Node* p_node;
UnaryNode(onnxruntime::Graph& graph, const std::string& op,
onnxruntime::NodeArg* p_input_arg, onnxruntime::NodeArg* p_output_arg)
UnaryNode(onnxruntime::Graph& graph, const std::string& op, onnxruntime::NodeArg* p_input_arg,
onnxruntime::NodeArg* p_output_arg)
: input_args({p_input_arg}), output_args({p_output_arg}) {
int num = NodeCounter::Next();
p_node = &graph.AddNode("node" + std::to_string(num), op, "test op", input_args, output_args);
@ -161,9 +161,11 @@ class PlannerTest : public ::testing::Test {
std::unique_ptr<SequentialExecutionPlan> plan_;
public:
PlannerTest() : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) {
std_kernel_ = KernelDefBuilder().SetName("Transpose").Build();
in_place_kernel_ = KernelDefBuilder().SetName("Relu").MayInplace(0, 0).Build();
PlannerTest()
: model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) {
std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
in_place_kernel_ =
KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build();
CPUExecutionProviderInfo epi;
auto execution_provider = std::make_unique<CPUExecutionProvider>(epi);
execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider));
@ -194,18 +196,20 @@ class PlannerTest : public ::testing::Test {
return AddNode(*in_place_kernel_, input, output);
}
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) {
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) {
auto info = std::make_unique<OpKernelInfo>(*p_node, kernel_def, *execution_providers_.Get(*p_node),
state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(),
state_.GetFuncMgr(), state_.GetDataTransferMgr());
auto dummy = std::make_unique<DummyOpKernel>(*info);
op_kernel_infos_.push_back(std::move(info));
state_.AddKernel(p_node->Index(), std::move(dummy));
if (reg->TryFindKernel(*p_node, onnxruntime::kCpuExecutionProvider) == nullptr) {
auto st = reg->Register(
KernelCreateInfo(std::make_unique<KernelDef>(kernel_def),
[](const OpKernelInfo& info) -> OpKernel* { return new DummyOpKernel(info); }));
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
}
}
// Associates `shape` with the NodeArg returned by Arg(name) in shape_map_.
void SetShape(std::string& name, TensorShapeProto* shape) { shape_map_[Arg(name)] = shape; }
void SetShape(std::initializer_list<std::pair<std::string&, TensorShapeProto*>> shapes) {
for (auto& pair : shapes) {
@ -215,29 +219,27 @@ class PlannerTest : public ::testing::Test {
void CreatePlan(const std::vector<const NodeArg*>& outer_scope_node_args = {}) {
EXPECT_EQ(graph_.Resolve(), Status::OK());
state_.SetGraphViewer(std::make_unique<GraphViewer>(graph_));
OrtValueNameIdxMap& mlvalue_name_idx_map{state_.GetOrtValueNameIdxMap()};
state_.SetGraph(graph_);
int count = 0;
for (auto& pair : name_to_arg_) {
EXPECT_EQ(mlvalue_name_idx_map.Add(pair.first), count++);
}
std::shared_ptr<KernelRegistry> reg = std::make_shared<KernelRegistry>();
for (auto& binding : kernel_bindings_) {
BindKernel(binding.first, binding.second);
BindKernel(binding.first, binding.second, reg.get());
}
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
KernelRegistryManager kernel_registry_manager;
kernel_registry_manager.RegisterKernelRegistry(reg);
ExecutionProviders execution_providers;
execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::move(cpu_execution_provider));
auto status = kernel_registry_manager.RegisterKernels(execution_providers);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
status = state_.CreateKernels(kernel_registry_manager);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
SequentialPlannerTestContext test_context(&shape_map_);
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers,
kernel_registry_manager, mlvalue_name_idx_map, test_context, plan_);
kernel_registry_manager, state_.GetOrtValueNameIdxMap(), test_context, plan_);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size());

View file

@ -48,9 +48,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
onnxruntime::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float);
graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def});
onnxruntime::Node* node = graph.GetNode(graph.NumberOfNodes() - 1);
onnxruntime::Node* node = &graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def});
node->SetExecutionProviderType(kCpuExecutionProvider);
Status status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
@ -63,11 +62,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
SessionState state{execution_providers, true, &tp_};
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));
OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
mlvalue_name_idx_map.Add("X");
mlvalue_name_idx_map.Add("Y");
status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
node->SetExecutionProviderType(xp_typ);
@ -75,12 +71,10 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
// TODO below line is for testing only. In production use SequentialPlanner::CreatePlan()
SequentialPlannerContext context(false);
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager,
mlvalue_name_idx_map, context, p_seq_exec_plan);
state.GetOrtValueNameIdxMap(), context, p_seq_exec_plan);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
state.SetExecutionPlan(std::move(p_seq_exec_plan));
state.CalculateNodeIndexInfo();
vector<OrtValue> outputs;
ExecutionFrame frame({}, {}, {}, outputs, {}, state);
@ -117,21 +111,22 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
}
TEST_F(ExecutionFrameTest, FeedInDataTest) {
onnxruntime::Model model("test");
onnxruntime::Model model("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
std::unordered_map<std::string, int>{{"", 10}});
onnxruntime::Graph& graph = model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
onnxruntime::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float);
graph.AddNode("node1", "Clip", "Clip operator", ArgMap{&input_def}, ArgMap{&output_def});
graph.AddNode("node1", "Clip", "Clip operator", ArgMap{&input_def}, ArgMap{&output_def})
.SetExecutionProviderType(kCpuExecutionProvider);
graph.Resolve();
auto cpu_allocator = TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault);
auto element_type = DataTypeImpl::GetType<float>();
TensorShape shape({3, 2});
std::vector<float> fdata(static_cast<size_t>(shape.Size()));
//create fake ml value with owned buffer.
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(element_type,
shape,
cpu_allocator);
OrtAllocatorInfo cpuinfo(kCpuExecutionProvider, OrtDeviceAllocator);
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(element_type, shape, fdata.data(), cpuinfo);
OrtValue value;
value.Init(p_tensor.release(),
DataTypeImpl::GetType<Tensor>(),
@ -144,15 +139,14 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
ExecutionProviders execution_providers;
execution_providers.Add(xp_typ, std::move(cpu_xp));
EXPECT_TRUE(kernel_registry_manager.RegisterKernels(execution_providers).IsOK());
SessionState state{execution_providers, true, &tp_};
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));
auto status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
auto x_idx = mlvalue_name_idx_map.Add("X");
auto y_idx = mlvalue_name_idx_map.Add("Y");
state.CalculateNodeIndexInfo();
const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap();
int x_idx, y_idx;
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK());
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK());
vector<OrtValue> outputs;
ExecutionFrame frame({x_idx}, {value}, {y_idx}, outputs, {}, state);
@ -198,16 +192,20 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
kernel_registry_manager.RegisterKernels(execution_providers);
//1. prepare input
SessionState state{execution_providers, true, &tp_};
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));
status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
const OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
auto x1_idx = mlvalue_name_idx_map.Add("X1");
auto x2_idx = mlvalue_name_idx_map.Add("X2");
auto x3_idx = mlvalue_name_idx_map.Add("X3");
mlvalue_name_idx_map.Add("T1");
mlvalue_name_idx_map.Add("T2");
auto t3_idx = mlvalue_name_idx_map.Add("T3");
int x1_idx, x2_idx, x3_idx;
int t1_idx, t2_idx, t3_idx;
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X1", x1_idx).IsOK());
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X2", x2_idx).IsOK());
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X3", x3_idx).IsOK());
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T1", t1_idx).IsOK());
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T2", t2_idx).IsOK());
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T3", t3_idx).IsOK());
auto cpu_allocator = execution_providers.Get(xp_type)->GetAllocator(0, OrtMemTypeDefault);
@ -230,8 +228,6 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
state.SetExecutionPlan(std::move(p_seq_exec_plan));
state.CalculateNodeIndexInfo();
vector<OrtValue> outputs;
ExecutionFrame frame({x1_idx, x2_idx, x3_idx}, {v1, v2, v3}, {t3_idx}, outputs, {}, state);

View file

@ -53,19 +53,28 @@ TEST(SessionStateTest, AddGetKernelTest) {
outputs.push_back(&output_arg);
onnxruntime::Node& node = graph.AddNode("node_1", "Variable", "node 1.", inputs, outputs);
auto status = graph.Resolve();
EXPECT_TRUE(status.IsOK());
KernelDef kernel_def;
CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}};
ASSERT_TRUE(status.IsOK());
auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
OpKernelInfo p_info(node, kernel_def, execution_provider, s.GetConstantInitializedTensors(),
OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider.get(), s.GetConstantInitializedTensors(),
s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr());
unique_ptr<TestOpKernel> p_kernel;
p_kernel.reset(new TestOpKernel(p_info));
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
std::cout << "node_idx: " << node.Index() << std::endl;
s.SetGraphViewer(std::make_unique<GraphViewer>(graph));
s.AddKernel(node.Index(), std::move(p_kernel));
execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider));
KernelRegistryManager kernel_registry_manager;
status = kernel_registry_manager.RegisterKernels(execution_providers);
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
node.SetExecutionProviderType(kCpuExecutionProvider);
std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>();
kernel_registry->Register(KernelCreateInfo(
std::move(kernel_def), [](const OpKernelInfo& info) -> OpKernel* { return new TestOpKernel(info); }));
kernel_registry_manager.RegisterKernelRegistry(kernel_registry);
status = s.SetGraphAndCreateKernels(graph, kernel_registry_manager);
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
auto test_kernel = s.GetKernel(node.Index());
std::cout << "orig: " << orig_num_outputs << " new: " << test_kernel->Node().OutputDefs().size() << std::endl;
EXPECT_EQ(orig_num_outputs, test_kernel->Node().OutputDefs().size());
@ -79,8 +88,7 @@ class TestParam {
};
TestParam param_list[] = {{3, true}, {4, true}, {3, false}, {4, false}};
} // namespace
class SessionStateTestP : public testing::TestWithParam<TestParam> {
};
class SessionStateTestP : public testing::TestWithParam<TestParam> {};
// Test that we separate out constant and non-constant initializers correctly
TEST_P(SessionStateTestP, TestInitializerProcessing) {
const TestParam& param = GetParam();
@ -104,8 +112,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
ASSERT_TRUE(status.IsOK()) << status;
SessionState session_state(execution_providers, param.enable_mem_pattern, &tp);
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph,
session_state, execution_providers, krm);
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, session_state,
execution_providers, krm);
GraphPartitioner partitioner(krm, execution_providers);
status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr());
@ -114,9 +122,6 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
status = session_initializer.CreatePlan(nullptr, nullptr, true);
ASSERT_TRUE(status.IsOK()) << status;
status = session_initializer.InitializeAndSave(nullptr);
ASSERT_TRUE(status.IsOK()) << status;
const auto& initialized_tensors = session_state.GetInitializedTensors();
const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors();
@ -144,7 +149,6 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
}
}
INSTANTIATE_TEST_CASE_P(SessionStateTests, SessionStateTestP,
testing::ValuesIn(param_list));
INSTANTIATE_TEST_CASE_P(SessionStateTests, SessionStateTestP, testing::ValuesIn(param_list));
} // namespace test
} // namespace onnxruntime

View file

@ -135,7 +135,7 @@ TEST(CApiTest, load_float_tensor_with_external_data) {
}
#if defined(__amd64__) || defined(_M_X64)
#ifdef NDEBUG
TEST(CApiTest, load_huge_tensor_with_external_data) {
FILE* fp;
std::basic_string<ORTCHAR_T> filename(ORT_TSTR("tensor_XXXXXX"));
@ -183,5 +183,6 @@ TEST(CApiTest, load_huge_tensor_with_external_data) {
}
}
#endif
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -1,225 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <benchmark/benchmark.h>
#include <core/graph/model.h>
#include <core/framework/execution_providers.h>
#include <core/framework/kernel_registry_manager.h>
#include <core/framework/session_state.h>
#include <core/framework/graph_partitioner.h>
#include <core/providers/cpu/cpu_execution_provider.h>
#ifdef USE_CUDA
#include <core/providers/cuda/cuda_execution_provider.h>
#endif
#ifdef USE_MKLDNN
#include <core/providers/mkldnn/mkldnn_execution_provider.h>
#endif
#include <core/platform/env.h>
#include <core/graph/onnx_protobuf.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
using namespace google::protobuf::io;
// Text-format ONNX ModelProto for a minimal test model, parsed by
// CreateModelFromStr() via google::protobuf::TextFormat.
// The graph ("test-model") holds a single MatMul node that consumes the graph
// input "X" twice and produces the graph output "Y". Both "X" and "Y" are
// declared as tensors with elem_type 1 and shape [2, 2]. The model imports
// opset version 8 of the default domain (no domain field is set).
constexpr const char* model_str =
"ir_version: 4\n"
"graph {\n"
" node {\n"
" input: \"X\"\n"
" input: \"X\"\n"
" output: \"Y\"\n"
" op_type: \"MatMul\"\n"
" }\n"
" name: \"test-model\"\n"
" input {\n"
" name: \"X\"\n"
" type {\n"
" tensor_type {\n"
" elem_type: 1\n"
" shape {\n"
" dim {\n"
" dim_value: 2\n"
" }\n"
" dim {\n"
" dim_value: 2\n"
" }\n"
" }\n"
" }\n"
" }\n"
" }\n"
" output {\n"
" name: \"Y\"\n"
" type {\n"
" tensor_type {\n"
" elem_type: 1\n"
" shape {\n"
" dim {\n"
" dim_value: 2\n"
" }\n"
" dim {\n"
" dim_value: 2\n"
" }\n"
" }\n"
" }\n"
" }\n"
" }\n"
"}\n"
"opset_import {\n"
" version: 8\n"
"}";
using namespace onnxruntime;
// Evaluates `expr` exactly once; the result must expose IsOK() and
// ErrorMessage(). On failure the message is reported through
// state.SkipWithError(), so a variable named `state` must be in scope at the
// call site.
// NOTE: despite its name this macro neither breaks out of a loop nor returns;
// control simply continues with the statement following the macro call.
#define BM_BREAK_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if ((!_status.IsOK())) state.SkipWithError(_status.ErrorMessage().c_str()); \
} while (0)
// Parses the text-format ModelProto in `str` and stores a newly created Model
// in `*out`.
// Throws std::runtime_error("load model failed") when the text cannot be
// parsed; otherwise returns Status::OK().
Status CreateModelFromStr(const char* str, std::unique_ptr<Model>* out) {
  ONNX_NAMESPACE::ModelProto model_proto;
  const bool parsed = google::protobuf::TextFormat::ParseFromString(str, &model_proto);
  if (!parsed) {
    throw std::runtime_error("load model failed");
  }
  *out = std::make_unique<Model>(model_proto);
  return Status::OK();
}
// Creates the ExecutionProviders collection used by the benchmarks and hands
// it to the caller through `*ret`.
// Providers are added in this order: CUDA (only when USE_CUDA is defined),
// MKL-DNN (only when USE_MKLDNN is defined), then CPU. Each provider is built
// from a default-constructed info struct. The first failing Add() is returned
// to the caller and `*ret` is left untouched in that case.
Status CreateExecutionProviders(std::unique_ptr<ExecutionProviders>* ret) {
  auto providers = std::make_unique<ExecutionProviders>();

#ifdef USE_CUDA
  CUDAExecutionProviderInfo cuda_info;
  ORT_RETURN_IF_ERROR(providers->Add(onnxruntime::kCudaExecutionProvider,
                                     std::make_unique<CUDAExecutionProvider>(cuda_info)));
#endif

#ifdef USE_MKLDNN
  MKLDNNExecutionProviderInfo mkldnn_info;
  ORT_RETURN_IF_ERROR(providers->Add(onnxruntime::kMklDnnExecutionProvider,
                                     std::make_unique<MKLDNNExecutionProvider>(mkldnn_info)));
#endif

  CPUExecutionProviderInfo cpu_info;
  ORT_RETURN_IF_ERROR(providers->Add(onnxruntime::kCpuExecutionProvider,
                                     std::make_unique<CPUExecutionProvider>(cpu_info)));

  *ret = std::move(providers);
  return Status::OK();
}
// Builds a KernelRegistryManager for `model` and returns it through `*ret`.
// Steps, in order: create the execution providers, register their kernels,
// set up a temporary SessionState (default logger, thread pool `tp`), resolve
// the model's main graph, attach a GraphViewer to the session state and run
// the GraphPartitioner over the graph. Any failing step is returned as-is and
// `*ret` is only assigned on success.
Status CreateKernelRegistryManagerFromModel(std::unique_ptr<KernelRegistryManager>* ret, Model* model, concurrency::ThreadPool& tp) {
  std::unique_ptr<ExecutionProviders> providers;
  ORT_RETURN_IF_ERROR(CreateExecutionProviders(&providers));

  auto registry_manager = std::make_unique<KernelRegistryManager>();
  ORT_RETURN_IF_ERROR(registry_manager->RegisterKernels(*providers));

  SessionState session_state{*providers, true, &tp};
  session_state.SetLogger(logging::LoggingManager::DefaultLogger());

  auto& main_graph = model->MainGraph();
  ORT_RETURN_IF_ERROR(main_graph.Resolve());
  session_state.SetGraphViewer(std::make_unique<GraphViewer>(main_graph));

  GraphPartitioner partitioner(*registry_manager, *providers);
  ORT_RETURN_IF_ERROR(
      partitioner.Partition(main_graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()));

  *ret = std::move(registry_manager);
  return Status::OK();
}
// Benchmark body shared by the BM_SearchKernelRegistry_* benchmarks: measures
// KernelRegistryManager::SearchKernelRegistry() over every node of `model`'s
// main graph.
//
// `model` may be null (e.g. when the caller failed to load it); the benchmark
// is then skipped with an error instead of dereferencing the pointer.
// Throws std::runtime_error("failed") when the kernel registry manager cannot
// be built for the model.
static void SearchKernelRegistry_IMPL(benchmark::State& state, Model* model) {
  if (model == nullptr) {
    state.SkipWithError("model is null");
    return;
  }
  std::unique_ptr<KernelRegistryManager> kernel_registry_manager;
  concurrency::ThreadPool tp{"test", 1};
  auto st = CreateKernelRegistryManagerFromModel(&kernel_registry_manager, model, tp);
  if (!st.IsOK()) throw std::runtime_error("failed");
  for (auto _ : state) {
    for (const auto& n : model->MainGraph().Nodes()) {
      // Initialised so that a failed search never leaves an indeterminate
      // pointer behind.
      const KernelCreateInfo* info = nullptr;
      const auto search_status = kernel_registry_manager->SearchKernelRegistry(n, &info);
      if (!search_status.IsOK()) {
        // A ranged-for benchmark loop keeps iterating after SkipWithError(),
        // so leave the benchmark explicitly.
        state.SkipWithError(search_status.ErrorMessage().c_str());
        return;
      }
      if (info == nullptr) {
        state.SkipWithError("Search kernel failed");
        return;
      }
    }
  }
}
// Benchmarks kernel lookup on the single-node model described by `model_str`.
// Throws std::runtime_error("failed") if the model cannot be created.
static void BM_SearchKernelRegistry_SingleNodeModel(benchmark::State& state) {
  std::unique_ptr<Model> single_node_model;
  Status create_status = CreateModelFromStr(model_str, &single_node_model);
  if (!create_status.IsOK()) {
    throw std::runtime_error("failed");
  }
  SearchKernelRegistry_IMPL(state, single_node_model.get());
}
BENCHMARK(BM_SearchKernelRegistry_SingleNodeModel);
// Benchmarks kernel lookup on the tiny YOLOv2 model loaded from a path
// relative to the working directory.
// If the model cannot be loaded the benchmark is skipped with the load error
// rather than passing a null model on to SearchKernelRegistry_IMPL.
static void BM_SearchKernelRegistry_RealModel_tiny_yolo(benchmark::State& state) {
  std::shared_ptr<onnxruntime::Model> model;
  auto st = onnxruntime::Model::Load("../models/opset8/test_tiny_yolov2/model.onnx", model);
  if (!st.IsOK()) {
    state.SkipWithError(st.ErrorMessage().c_str());
    return;
  }
  SearchKernelRegistry_IMPL(state, model.get());
}
BENCHMARK(BM_SearchKernelRegistry_RealModel_tiny_yolo);
// Benchmarks kernel lookup on the Inception v4 model loaded from a path
// relative to the working directory.
// If the model cannot be loaded the benchmark is skipped with the load error
// rather than passing a null model on to SearchKernelRegistry_IMPL.
static void BM_SearchKernelRegistry_RealModel_inception_v4(benchmark::State& state) {
  std::shared_ptr<onnxruntime::Model> model;
  auto st = onnxruntime::Model::Load("../models/opset9/tf_inception_v4/model.onnx", model);
  if (!st.IsOK()) {
    state.SkipWithError(st.ErrorMessage().c_str());
    return;
  }
  SearchKernelRegistry_IMPL(state, model.get());
}
BENCHMARK(BM_SearchKernelRegistry_RealModel_inception_v4);
static void BM_PartitionModel_tiny_yolo(benchmark::State& state) {
int fd;
Status status = Env::Default().FileOpenRd("../models/opset8/test_tiny_yolov2/model.onnx", fd);
if (!status.IsOK()) throw std::runtime_error("open test data failed");
auto raw_input = std::unique_ptr<ZeroCopyInputStream>(std::make_unique<FileInputStream>(fd));
auto coded_input = std::make_unique<CodedInputStream>(raw_input.get());
ONNX_NAMESPACE::ModelProto model_proto;
if (!model_proto.ParseFromCodedStream(coded_input.get())) throw std::runtime_error("open test data failed");
std::unique_ptr<ExecutionProviders> execution_providers;
BM_BREAK_IF_ERROR(CreateExecutionProviders(&execution_providers));
std::unique_ptr<KernelRegistryManager> kernel_registry_manager = std::make_unique<KernelRegistryManager>();
status = kernel_registry_manager->RegisterKernels(*execution_providers);
if (!status.IsOK()) throw std::runtime_error("RegisterKernels failed");
concurrency::ThreadPool tp{"test", 1};
for (auto _ : state) {
state.PauseTiming();
std::shared_ptr<onnxruntime::Model> model = std::make_shared<onnxruntime::Model>(model_proto);
SessionState s{*execution_providers, true, &tp};
s.SetLogger(logging::LoggingManager::DefaultLogger());
BM_BREAK_IF_ERROR(model->MainGraph().Resolve());
s.SetGraphViewer(std::make_unique<GraphViewer>(model->MainGraph()));
GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers);
state.ResumeTiming();
BM_BREAK_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr()));
}
}
BENCHMARK(BM_PartitionModel_tiny_yolo);
static void BM_PartitionModel_inception_v4(benchmark::State& state) {
int fd;
Status status = Env::Default().FileOpenRd("../models/opset9/tf_inception_v4/model.onnx", fd);
if (!status.IsOK()) throw std::runtime_error("open test data failed");
auto raw_input = std::unique_ptr<ZeroCopyInputStream>(std::make_unique<FileInputStream>(fd));
auto coded_input = std::make_unique<CodedInputStream>(raw_input.get());
ONNX_NAMESPACE::ModelProto model_proto;
if (!model_proto.ParseFromCodedStream(coded_input.get())) throw std::runtime_error("open test data failed");
std::unique_ptr<ExecutionProviders> execution_providers;
BM_BREAK_IF_ERROR(CreateExecutionProviders(&execution_providers));
std::unique_ptr<KernelRegistryManager> kernel_registry_manager = std::make_unique<KernelRegistryManager>();
status = kernel_registry_manager->RegisterKernels(*execution_providers);
if (!status.IsOK()) throw std::runtime_error("RegisterKernels failed");
concurrency::ThreadPool tp{"test", 1};
for (auto _ : state) {
state.PauseTiming();
std::shared_ptr<onnxruntime::Model> model = std::make_shared<onnxruntime::Model>(model_proto);
SessionState s{*execution_providers, true, &tp};
s.SetLogger(logging::LoggingManager::DefaultLogger());
BM_BREAK_IF_ERROR(model->MainGraph().Resolve());
s.SetGraphViewer(std::make_unique<GraphViewer>(model->MainGraph()));
GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers);
state.ResumeTiming();
BM_BREAK_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr()));
}
}
BENCHMARK(BM_PartitionModel_inception_v4);

View file

@ -42,14 +42,12 @@ TEST(MemcpyTest, copy1) {
Model model(mp);
st = model.MainGraph().Resolve();
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
s.SetGraphViewer(std::make_unique<GraphViewer>(model.MainGraph()));
PutAllNodesOnOneProvider(model.MainGraph(), onnxruntime::kCpuExecutionProvider);
SessionStateInitializer session_initializer{true, ORT_TSTR(""), model.MainGraph(),
s, execution_providers, kernel_registry_manager};
st = session_initializer.CreatePlan(nullptr, {}, true);
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
st = session_initializer.InitializeAndSave(nullptr);
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
AllocatorPtr allocator =
execution_providers.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);
auto* data_type = DataTypeImpl::GetType<float>();