From 4de0aa8049f9967d4e91020286f383022aa0b1fa Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 22 Aug 2019 10:26:35 -0700 Subject: [PATCH] Optimize kernel index (#1672) --- cmake/onnxruntime_unittests.cmake | 2 +- onnxruntime/core/framework/session_state.cc | 141 +++++++---- onnxruntime/core/framework/session_state.h | 49 ++-- .../framework/session_state_initializer.cc | 119 +-------- .../framework/session_state_initializer.h | 9 +- onnxruntime/core/session/inference_session.cc | 5 - .../test/framework/allocation_planner_test.cc | 42 ++-- .../test/framework/execution_frame_test.cc | 64 +++-- .../test/framework/session_state_test.cc | 34 +-- .../test/framework/test_tensor_loader.cc | 3 +- .../test/onnx/microbenchmark/model_init.cc | 225 ------------------ onnxruntime/test/providers/memcpy_test.cc | 4 +- 12 files changed, 214 insertions(+), 483 deletions(-) delete mode 100644 onnxruntime/test/onnx/microbenchmark/model_init.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d63cd90f7e..ec6cdd6793 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -526,7 +526,7 @@ install(TARGETS onnx_test_runner RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) if(onnxruntime_BUILD_BENCHMARKS) - add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc ${TEST_SRC_DIR}/onnx/microbenchmark/model_init.cc) + add_executable(onnxruntime_benchmark ${TEST_SRC_DIR}/onnx/microbenchmark/main.cc ${TEST_SRC_DIR}/onnx/microbenchmark/modeltest.cc) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} benchmark) onnxruntime_add_include_to_target(onnxruntime_benchmark gsl) if(WIN32) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a6fe46be95..6a0f76e552 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -11,23 +11,96 @@ #include "core/framework/utils.h" using namespace ::onnxruntime::common; + namespace onnxruntime { -void SessionState::SetGraphViewer(std::unique_ptr graph_viewer) { - ORT_ENFORCE(nullptr != graph_viewer); - graph_viewer_ = std::move(graph_viewer); -} - const GraphViewer* SessionState::GetGraphViewer() const { return graph_viewer_.get(); } +Status SessionState::SetGraph(const Graph& graph) { + graph_viewer_ = std::make_unique(graph); + auto& logger = Logger(); + // use graph_viewer_ to initialize ort_value_name_idx_map_ + LOGS(logger, INFO) << "SaveMLValueNameIndexMapping"; + int idx = 0; -const OpKernel* SessionState::GetKernel(NodeIndex node_id) const { - auto kernel = session_kernels_.find(node_id); - return (kernel != session_kernels_.cend()) ? kernel->second.get() : nullptr; + // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry + for (const auto* input_def : graph_viewer_->GetInputsIncludingInitializers()) { + idx = ort_value_name_idx_map_.Add(input_def->Name()); + VLOGS(logger, 1) << "Added graph_viewer_ input with name: " << input_def->Name() + << " to OrtValueIndex with index: " << idx; + } + + for (auto& node : graph_viewer_->Nodes()) { + // build the OrtValue->index map + for (const auto* input_def : node.InputDefs()) { + if (input_def->Exists()) { + idx = ort_value_name_idx_map_.Add(input_def->Name()); + VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name() + << " to OrtValueIndex with index: " << idx; + } + } + + for (const auto* input_def : node.ImplicitInputDefs()) { + if (input_def->Exists()) { + idx = ort_value_name_idx_map_.Add(input_def->Name()); + VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name() + << " to OrtValueIndex with index: " << idx; + } + } + + for (const auto* output_def : node.OutputDefs()) { + if (output_def->Exists()) { + ort_value_name_idx_map_.Add(output_def->Name()); + VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name() + << " to OrtValueIndex with index: " << idx; + } + } + } + + // allocate OrtValue for graph outputs when coming from initializers + for (const auto& output : graph_viewer_->GetOutputs()) { + if (output->Exists()) { + idx = ort_value_name_idx_map_.Add(output->Name()); + VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx; + } + } + + LOGS(logger, INFO) << "Done saving OrtValue mappings."; + return Status::OK(); } -void SessionState::AddKernel(onnxruntime::NodeIndex node_id, std::unique_ptr p_kernel) { - // assumes vector is already resize()'ed to the number of nodes in the graph - session_kernels_[node_id] = std::move(p_kernel); +Status SessionState::CreateKernels(const KernelRegistryManager& custom_registry_manager) { + const GraphNodes& nodes = graph_viewer_->Nodes(); + if (!nodes.empty()) { + size_t max_nodeid = 0; + for (auto& node : graph_viewer_->Nodes()) { + max_nodeid = std::max(max_nodeid, node.Index()); + } + session_kernels_.clear(); + session_kernels_.resize(max_nodeid + 1, nullptr); + for (auto& node : graph_viewer_->Nodes()) { + // construct and save the kernels + std::unique_ptr op_kernel; + onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType(); + + const IExecutionProvider* exec_provider = nullptr; + if (exec_provider_name.empty() || (exec_provider = execution_providers_.Get(exec_provider_name)) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(), + " as there's no execution provider allocated."); + } + + common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, *this, op_kernel); + if (!status.IsOK()) { + return common::Status( + status.Category(), status.Code(), + MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage())); + } + assert(session_kernels_[node.Index()] == nullptr); + // assumes vector is already resize()'ed to the number of nodes in the graph + session_kernels_[node.Index()] = op_kernel.release(); + } + } + node_index_info_ = std::make_unique(*graph_viewer_, ort_value_name_idx_map_); + return Status::OK(); } void SessionState::SetExecutionPlan(std::unique_ptr p_seq_exec_plan) { @@ -38,7 +111,6 @@ const SequentialExecutionPlan* SessionState::GetExecutionPlan() const { return p Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant) { - ORT_ENFORCE(ort_value_index >= 0 && ort_value_index <= ort_value_name_idx_map_.MaxIdx()); auto p = initialized_tensors_.insert({ort_value_index, ort_value}); if (!p.second) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated ort_value index:", ort_value_index, @@ -55,9 +127,7 @@ Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& o return Status::OK(); } -const std::unordered_map& SessionState::GetInitializedTensors() const { - return initialized_tensors_; -} +const std::unordered_map& SessionState::GetInitializedTensors() const { return initialized_tensors_; } const std::unordered_map& SessionState::GetConstantInitializedTensors() const { return constant_initialized_tensors_; @@ -86,7 +156,8 @@ static int64_t CalculateMemoryPatternsKey(const std::vector>& input_shapes) const { +const MemoryPatternGroup* SessionState::GetMemoryPatternGroup( + const std::vector>& input_shapes) const { int64_t key = CalculateMemoryPatternsKey(input_shapes); std::lock_guard lock(mem_patterns_lock_); @@ -96,8 +167,9 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< return it->second.get(); } -Status SessionState::UpdateMemoryPatternGroupCache(const std::vector>& input_shapes, - std::unique_ptr mem_patterns) const { +Status SessionState::UpdateMemoryPatternGroupCache( + const std::vector>& input_shapes, + std::unique_ptr mem_patterns) const { int64_t key = CalculateMemoryPatternsKey(input_shapes); std::lock_guard lock(mem_patterns_lock_); @@ -109,9 +181,7 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vectorName(), " (", current_provider, - ") and node ", node_info.p_node->Name(), " (", new_provider, ")."); + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "Using an input in multiple nodes on different devices is not supported currently. Input:", input_name, + " is used by node ", existing_entry.p_node->Name(), " (", current_provider, ") and node ", + node_info.p_node->Name(), " (", new_provider, ")."); } } } @@ -178,16 +249,15 @@ const SessionState::NameNodeInfoMapType& SessionState::GetOutputNodeInfoMap() co return output_names_to_nodeinfo_mapping_; } -void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, - const std::string& attribute_name, +void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name, std::unique_ptr session_state) { auto entry = subgraph_session_states_.find(index); // make sure this is new. internal logic error if it is not so using ORT_ENFORCE. if (entry != subgraph_session_states_.cend()) { const auto& existing_entries = entry->second; - ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), - "Entry exists in node ", index, " for attribute ", attribute_name); + ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index, + " for attribute ", attribute_name); } subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state))); @@ -215,19 +285,8 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex return const_cast(this)->GetMutableSubgraphSessionState(index, attribute_name); } -void SessionState::CalculateNodeIndexInfo() { - ORT_ENFORCE(graph_viewer_); - node_index_info_ = std::make_unique(*graph_viewer_, ort_value_name_idx_map_); - - for (auto& node_to_map_pair : subgraph_session_states_) { - for (auto& attr_name_to_subgraph : node_to_map_pair.second) { - attr_name_to_subgraph.second->CalculateNodeIndexInfo(); - } - } -} - const NodeIndexInfo& SessionState::GetNodeIndexInfo() const { - ORT_ENFORCE(node_index_info_, "CalculateNodeIndexInfo must be called prior to GetExecutionInfo."); + ORT_ENFORCE(node_index_info_, "SetGraphAndCreateKernels must be called prior to GetExecutionInfo."); return *node_index_info_; } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 0f64b2b943..b9e4c08900 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -40,33 +40,41 @@ struct MemoryPatternGroup; * SessionState should be modified by the inference session class only. * It is supposed to be passed by const-ref only to all the executors. * This class owns all the initializers. + * Brief usage: + * SessionState s(...); + * for(...) s.AddInitializedTensor(...); + * s.SetGraphAndCreateKernels(...); + * Then you can use: + * s.GetKernel(...); */ class SessionState { public: - SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern, concurrency::ThreadPool* thread_pool) + SessionState(const ExecutionProviders& execution_providers, bool enable_mem_pattern, + concurrency::ThreadPool* thread_pool) : execution_providers_{execution_providers}, enable_mem_pattern_(enable_mem_pattern), thread_pool_(thread_pool) {} ~SessionState() { + for (auto* p : session_kernels_) { + delete p; + } for (auto& kvp : deleter_for_initialized_tensors_) { kvp.second.f(kvp.second.param); } } // Graph viewer. - void SetGraphViewer(std::unique_ptr graph_viewer); const GraphViewer* GetGraphViewer() const; // kernels // Get kernel for specified node. // It should called right before graph execution only. - const OpKernel* GetKernel(NodeIndex node_id) const; - - void AddKernel(NodeIndex node_id, std::unique_ptr p_kernel); + const OpKernel* GetKernel(size_t node_id) const { + return (node_id < session_kernels_.size()) ? session_kernels_[node_id] : nullptr; + } const ExecutionProviders& GetExecutionProviders() const noexcept { return execution_providers_; } const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_name_idx_map_; } - OrtValueNameIdxMap& GetOrtValueNameIdxMap() noexcept { return ort_value_name_idx_map_; } // initialized tensors /** @@ -77,6 +85,12 @@ class SessionState { */ Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant); + Status SetGraph(const Graph& graph); + Status CreateKernels(const KernelRegistryManager& custom_registry_manager); + Status SetGraphAndCreateKernels(const Graph& graph, const KernelRegistryManager& custom_registry_manager) { + ORT_RETURN_IF_ERROR(SetGraph(graph)); + return CreateKernels(custom_registry_manager); + } /** * Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the * execution frame to setup the appropriate OrtValue vectors. @@ -85,8 +99,8 @@ class SessionState { const std::unordered_map& GetInitializedTensors() const; /** - * Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant - * and cannot be overridden at runtime. + * Gets the map of ort_value_index to initialized tensors (e.g. weights) that are constant + * and cannot be overridden at runtime. * The lifetime of returned OrtValues are limited by this SessionState object. */ const std::unordered_map& GetConstantInitializedTensors() const; @@ -96,12 +110,12 @@ class SessionState { const SequentialExecutionPlan* GetExecutionPlan() const; /** - Set the logger to use for this session. + Set the logger to use for this session. */ SessionState& SetLogger(const logging::Logger& logger); /** - Get the logger for this session. + Get the logger for this session. Falls back to returning Logging::LoggingManager::DefaultLogger if SetLogger has not been called. */ const logging::Logger& Logger() const; @@ -120,10 +134,11 @@ class SessionState { /** Get cached memory pattern based on input shapes */ - const MemoryPatternGroup* GetMemoryPatternGroup(const std::vector>& input_shapes) const; + const MemoryPatternGroup* GetMemoryPatternGroup( + const std::vector>& input_shapes) const; /** - Set generated memory pattern with a given input shapes. + Set generated memory pattern with a given input shapes. Const as it's an internal cache update only. */ Status UpdateMemoryPatternGroupCache(const std::vector>& input_shape, @@ -142,10 +157,7 @@ class SessionState { * \param kci0 Nullable */ NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0) - : index(index0), - p_node(p_node0), - kci(kci0), - device(&device0) {} + : index(index0), p_node(p_node0), kci(kci0), device(&device0) {} size_t index; // Nullable @@ -187,7 +199,6 @@ class SessionState { void SetDataTransferMgr(const DataTransferManager* data_transfer_mgr) { data_transfer_mgr_ = data_transfer_mgr; } std::vector& GetMutableWeightsBuffers() { return weights_buffers_; } - void CalculateNodeIndexInfo(); const NodeIndexInfo& GetNodeIndexInfo() const; private: @@ -195,7 +206,7 @@ class SessionState { // cache of the constructed kernels to avoid spending construction // time per executor - std::unordered_map> session_kernels_; + std::vector session_kernels_; std::unique_ptr graph_viewer_; const ExecutionProviders& execution_providers_; // owned by InferenceSession @@ -231,7 +242,7 @@ class SessionState { std::unordered_map>>; SubgraphSessionStateMap subgraph_session_states_; - //It could be NULL + // It could be NULL concurrency::ThreadPool* const thread_pool_; bool export_fused_dll_ = false; diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index 54dc8e4aa0..18589de826 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ b/onnxruntime/core/framework/session_state_initializer.cc @@ -27,9 +27,6 @@ namespace onnxruntime { -static common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, - OrtValueNameIdxMap& ort_value_name_idx_map, - const logging::Logger& logger); // T should have signature of '(int idx, const OrtValue& value, const OrtCallback& d) -> Status' template @@ -40,11 +37,6 @@ static common::Status SaveInitializedTensors(const Env& env, const std::basic_st const logging::Logger& logger, const DataTransferManager& data_transfer_mgr); -static common::Status SaveKernels(const ExecutionProviders& execution_providers, - SessionState& session_state, - const KernelRegistryManager& custom_registry_manager, - const logging::Logger& logger); - static common::Status SaveInputOutputNamesToNodeMapping( const onnxruntime::Graph& graph, const KernelRegistryManager& custom_registry_manager, @@ -68,11 +60,11 @@ common::Status SessionStateInitializer::CreatePlan( const Node* parent_node, const ConstPointerContainer>* outer_scope_node_args, bool enable_sequential_execution) { - auto graph_viewer = std::make_unique(graph_); + session_state_.SetGraph(graph_); + const GraphViewer* graph_viewer = session_state_.GetGraphViewer(); // populate the SessionState OrtValueNameIdxMap - auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap(); - ORT_RETURN_IF_ERROR(SaveMLValueNameIndexMapping(*graph_viewer, ort_value_name_idx_map, logger_)); + const auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap(); // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. std::vector valid_outer_scope_node_args; @@ -92,17 +84,10 @@ common::Status SessionStateInitializer::CreatePlan( execution_providers_, kernel_registry_manager_, ort_value_name_idx_map, context, exec_plan)); session_state_.SetExecutionPlan(std::move(exec_plan)); - session_state_.SetGraphViewer(std::move(graph_viewer)); - return Status::OK(); -} - -common::Status SessionStateInitializer::InitializeAndSave( - const ConstPointerContainer>* implicit_inputs) { const auto* exec_plan_ptr = session_state_.GetExecutionPlan(); ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first."); - const auto& ort_value_name_idx_map{session_state_.GetOrtValueNameIdxMap()}; std::unique_ptr tensor_allocator_(ITensorAllocator::Create( enable_mem_pattern_, *exec_plan_ptr, execution_providers_, session_state_.GetMutableWeightsBuffers())); @@ -119,64 +104,12 @@ common::Status SessionStateInitializer::InitializeAndSave( // TODO: make it better graph_.CleanAllInitializedTensors(); - ORT_RETURN_IF_ERROR(SaveKernels(execution_providers_, session_state_, kernel_registry_manager_, logger_)); - ORT_RETURN_IF_ERROR(SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_, - implicit_inputs)); - + ORT_RETURN_IF_ERROR(session_state_.CreateKernels(kernel_registry_manager_)); + ORT_RETURN_IF_ERROR( + SaveInputOutputNamesToNodeMapping(graph_, kernel_registry_manager_, session_state_, outer_scope_node_args)); return Status::OK(); } -// Build the OrtValue name->idx mapping -common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, OrtValueNameIdxMap& ort_value_name_idx_map, - const logging::Logger& logger) { - LOGS(logger, INFO) << "SaveMLValueNameIndexMapping"; - int idx = 0; - - // we keep all graph inputs (including initializers), even if they are unused, so make sure they all have an entry - for (const auto* input_def : graph_viewer.GetInputsIncludingInitializers()) { - idx = ort_value_name_idx_map.Add(input_def->Name()); - VLOGS(logger, 1) << "Added graph_viewer input with name: " << input_def->Name() - << " to OrtValueIndex with index: " << idx; - } - - for (auto& node : graph_viewer.Nodes()) { - // build the OrtValue->index map - for (const auto* input_def : node.InputDefs()) { - if (input_def->Exists()) { - idx = ort_value_name_idx_map.Add(input_def->Name()); - VLOGS(logger, 1) << "Added input argument with name: " << input_def->Name() - << " to OrtValueIndex with index: " << idx; - } - } - - for (const auto* input_def : node.ImplicitInputDefs()) { - if (input_def->Exists()) { - idx = ort_value_name_idx_map.Add(input_def->Name()); - VLOGS(logger, 1) << "Added implicit input argument with name: " << input_def->Name() - << " to OrtValueIndex with index: " << idx; - } - } - - for (const auto* output_def : node.OutputDefs()) { - if (output_def->Exists()) { - ort_value_name_idx_map.Add(output_def->Name()); - VLOGS(logger, 1) << "Added output argument with name: " << output_def->Name() - << " to OrtValueIndex with index: " << idx; - } - } - } - - // allocate OrtValue for graph outputs when coming from initializers - for (const auto& output : graph_viewer.GetOutputs()) { - if (output->Exists()) { - idx = ort_value_name_idx_map.Add(output->Name()); - VLOGS(logger, 1) << "Added graph output with name: " << output->Name() << " to OrtValueIndex with index: " << idx; - } - } - - LOGS(logger, INFO) << "Done saving OrtValue mappings."; - return Status::OK(); -} static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, @@ -292,46 +225,6 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string& op_kernel) { - onnxruntime::ProviderType exec_provider_name = node.GetExecutionProviderType(); - - const IExecutionProvider* exec_provider = nullptr; - if (exec_provider_name.empty() || (exec_provider = execution_providers.Get(exec_provider_name)) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not create kernel for node: ", node.Name(), - " as there's no execution provider allocated."); - } - - common::Status status = custom_registry_manager.CreateKernel(node, *exec_provider, session_state, op_kernel); - if (!status.IsOK()) { - return common::Status( - status.Category(), status.Code(), - MakeString("Kernel creation failed for node: ", node.Name(), " with error: ", status.ErrorMessage())); - } - - return status; -} - -common::Status SaveKernels(const ExecutionProviders& execution_providers, - SessionState& session_state, - const KernelRegistryManager& custom_registry_manager, - const logging::Logger& logger) { - LOGS(logger, INFO) << "Saving kernels."; - - for (auto& node : session_state.GetGraphViewer()->Nodes()) { - // construct and save the kernels - std::unique_ptr op_kernel; - ORT_RETURN_IF_ERROR(CreateOpKernel(node, execution_providers, session_state, custom_registry_manager, op_kernel)); - session_state.AddKernel(node.Index(), std::move(op_kernel)); - } - - LOGS(logger, INFO) << "Done saving kernels."; - - return Status::OK(); -} - template // T is container of const NodeArg* or NodeArg* static bool IsArgNameInInputsOutputs(const std::string& name, const T& graph_args) { diff --git a/onnxruntime/core/framework/session_state_initializer.h b/onnxruntime/core/framework/session_state_initializer.h index 3634704de5..8c969571c5 100644 --- a/onnxruntime/core/framework/session_state_initializer.h +++ b/onnxruntime/core/framework/session_state_initializer.h @@ -36,14 +36,11 @@ class SessionStateInitializer { KernelRegistryManager& kernel_registry_manager); // First perform any transformations and create the execution plan - common::Status CreatePlan(const Node* parent_node, - const ConstPointerContainer>* outer_scope_node_args, + // Then initialize tensors, and save. save kernels and input/output node mappings + common::Status CreatePlan(_In_opt_ const Node* parent_node, + _In_opt_ const ConstPointerContainer>* outer_scope_node_args, bool enable_sequential_execution); - // initialize tensors, and save. save kernels and input/output node mappings - // \param implicit_inputs could be NULL - common::Status InitializeAndSave(const ConstPointerContainer>* implicit_inputs); - private: const std::basic_string& graph_loc_; onnxruntime::Graph& graph_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4d8fa70784..829414ecbe 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -435,7 +435,6 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs, session_options_.enable_sequential_execution)); - ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&implicit_inputs)); // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), // &*subgraph_info.session_state); @@ -533,13 +532,9 @@ common::Status InferenceSession::Initialize() { } ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution)); - ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr)); // handle any subgraphs ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_)); - - session_state_.CalculateNodeIndexInfo(); - is_inited_ = true; LOGS(*session_logger_, INFO) << "Session successfully initialized."; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 29736fec41..362dbf2c64 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -34,8 +34,8 @@ struct UnaryNode { std::vector output_args; onnxruntime::Node* p_node; - UnaryNode(onnxruntime::Graph& graph, const std::string& op, - onnxruntime::NodeArg* p_input_arg, onnxruntime::NodeArg* p_output_arg) + UnaryNode(onnxruntime::Graph& graph, const std::string& op, onnxruntime::NodeArg* p_input_arg, + onnxruntime::NodeArg* p_output_arg) : input_args({p_input_arg}), output_args({p_output_arg}) { int num = NodeCounter::Next(); p_node = &graph.AddNode("node" + std::to_string(num), op, "test op", input_args, output_args); @@ -161,9 +161,11 @@ class PlannerTest : public ::testing::Test { std::unique_ptr plan_; public: - PlannerTest() : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) { - std_kernel_ = KernelDefBuilder().SetName("Transpose").Build(); - in_place_kernel_ = KernelDefBuilder().SetName("Relu").MayInplace(0, 0).Build(); + PlannerTest() + : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) { + std_kernel_ = KernelDefBuilder().SetName("Transpose").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); + in_place_kernel_ = + KernelDefBuilder().SetName("Relu").Provider(kCpuExecutionProvider).SinceVersion(1, 10).MayInplace(0, 0).Build(); CPUExecutionProviderInfo epi; auto execution_provider = std::make_unique(epi); execution_providers_.Add("CPUExecutionProvider", std::move(execution_provider)); @@ -194,18 +196,20 @@ class PlannerTest : public ::testing::Test { return AddNode(*in_place_kernel_, input, output); } - void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) { + void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg) { auto info = std::make_unique(*p_node, kernel_def, *execution_providers_.Get(*p_node), state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(), state_.GetFuncMgr(), state_.GetDataTransferMgr()); - auto dummy = std::make_unique(*info); op_kernel_infos_.push_back(std::move(info)); - state_.AddKernel(p_node->Index(), std::move(dummy)); + if (reg->TryFindKernel(*p_node, onnxruntime::kCpuExecutionProvider) == nullptr) { + auto st = reg->Register( + KernelCreateInfo(std::make_unique(kernel_def), + [](const OpKernelInfo& info) -> OpKernel* { return new DummyOpKernel(info); })); + ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); + } } - void SetShape(std::string& name, TensorShapeProto* shape) { - shape_map_[Arg(name)] = shape; - } + void SetShape(std::string& name, TensorShapeProto* shape) { shape_map_[Arg(name)] = shape; } void SetShape(std::initializer_list> shapes) { for (auto& pair : shapes) { @@ -215,29 +219,27 @@ class PlannerTest : public ::testing::Test { void CreatePlan(const std::vector& outer_scope_node_args = {}) { EXPECT_EQ(graph_.Resolve(), Status::OK()); - state_.SetGraphViewer(std::make_unique(graph_)); - OrtValueNameIdxMap& mlvalue_name_idx_map{state_.GetOrtValueNameIdxMap()}; + state_.SetGraph(graph_); - int count = 0; - for (auto& pair : name_to_arg_) { - EXPECT_EQ(mlvalue_name_idx_map.Add(pair.first), count++); - } + std::shared_ptr reg = std::make_shared(); for (auto& binding : kernel_bindings_) { - BindKernel(binding.first, binding.second); + BindKernel(binding.first, binding.second, reg.get()); } auto cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo()); KernelRegistryManager kernel_registry_manager; + kernel_registry_manager.RegisterKernelRegistry(reg); ExecutionProviders execution_providers; execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::move(cpu_execution_provider)); auto status = kernel_registry_manager.RegisterKernels(execution_providers); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - + status = state_.CreateKernels(kernel_registry_manager); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); SequentialPlannerTestContext test_context(&shape_map_); status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers, - kernel_registry_manager, mlvalue_name_idx_map, test_context, plan_); + kernel_registry_manager, state_.GetOrtValueNameIdxMap(), test_context, plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 8253cbe5c8..ee4ce1f58c 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -48,9 +48,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); onnxruntime::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float); - graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def}); - onnxruntime::Node* node = graph.GetNode(graph.NumberOfNodes() - 1); - + onnxruntime::Node* node = &graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def}); + node->SetExecutionProviderType(kCpuExecutionProvider); Status status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -63,11 +62,8 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); SessionState state{execution_providers, true, &tp_}; - state.SetGraphViewer(std::make_unique(graph)); - - OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()}; - mlvalue_name_idx_map.Add("X"); - mlvalue_name_idx_map.Add("Y"); + status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); node->SetExecutionProviderType(xp_typ); @@ -75,12 +71,10 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { // TODO below line is for testing only. In production use SequentialPlanner::CreatePlan() SequentialPlannerContext context(false); status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager, - mlvalue_name_idx_map, context, p_seq_exec_plan); + state.GetOrtValueNameIdxMap(), context, p_seq_exec_plan); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); state.SetExecutionPlan(std::move(p_seq_exec_plan)); - state.CalculateNodeIndexInfo(); - vector outputs; ExecutionFrame frame({}, {}, {}, outputs, {}, state); @@ -117,21 +111,22 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { } TEST_F(ExecutionFrameTest, FeedInDataTest) { - onnxruntime::Model model("test"); + onnxruntime::Model model("test", false, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), + std::unordered_map{{"", 10}}); onnxruntime::Graph& graph = model.MainGraph(); TypeProto tensor_float; tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); onnxruntime::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float); - graph.AddNode("node1", "Clip", "Clip operator", ArgMap{&input_def}, ArgMap{&output_def}); + graph.AddNode("node1", "Clip", "Clip operator", ArgMap{&input_def}, ArgMap{&output_def}) + .SetExecutionProviderType(kCpuExecutionProvider); graph.Resolve(); - auto cpu_allocator = TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault); auto element_type = DataTypeImpl::GetType(); TensorShape shape({3, 2}); + std::vector fdata(static_cast(shape.Size())); //create fake ml value with owned buffer. - std::unique_ptr p_tensor = std::make_unique(element_type, - shape, - cpu_allocator); + OrtAllocatorInfo cpuinfo(kCpuExecutionProvider, OrtDeviceAllocator); + std::unique_ptr p_tensor = std::make_unique(element_type, shape, fdata.data(), cpuinfo); OrtValue value; value.Init(p_tensor.release(), DataTypeImpl::GetType(), @@ -144,15 +139,14 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ExecutionProviders execution_providers; execution_providers.Add(xp_typ, std::move(cpu_xp)); EXPECT_TRUE(kernel_registry_manager.RegisterKernels(execution_providers).IsOK()); - SessionState state{execution_providers, true, &tp_}; - state.SetGraphViewer(std::make_unique(graph)); + auto status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()}; - auto x_idx = mlvalue_name_idx_map.Add("X"); - auto y_idx = mlvalue_name_idx_map.Add("Y"); - - state.CalculateNodeIndexInfo(); + const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap(); + int x_idx, y_idx; + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK()); + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK()); vector outputs; ExecutionFrame frame({x_idx}, {value}, {y_idx}, outputs, {}, state); @@ -198,16 +192,20 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { kernel_registry_manager.RegisterKernels(execution_providers); //1. prepare input SessionState state{execution_providers, true, &tp_}; - state.SetGraphViewer(std::make_unique(graph)); + status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()}; + const OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()}; - auto x1_idx = mlvalue_name_idx_map.Add("X1"); - auto x2_idx = mlvalue_name_idx_map.Add("X2"); - auto x3_idx = mlvalue_name_idx_map.Add("X3"); - mlvalue_name_idx_map.Add("T1"); - mlvalue_name_idx_map.Add("T2"); - auto t3_idx = mlvalue_name_idx_map.Add("T3"); + int x1_idx, x2_idx, x3_idx; + int t1_idx, t2_idx, t3_idx; + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X1", x1_idx).IsOK()); + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X2", x2_idx).IsOK()); + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X3", x3_idx).IsOK()); + + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T1", t1_idx).IsOK()); + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T2", t2_idx).IsOK()); + ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("T3", t3_idx).IsOK()); auto cpu_allocator = execution_providers.Get(xp_type)->GetAllocator(0, OrtMemTypeDefault); @@ -230,8 +228,6 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { state.SetExecutionPlan(std::move(p_seq_exec_plan)); - state.CalculateNodeIndexInfo(); - vector outputs; ExecutionFrame frame({x1_idx, x2_idx, x3_idx}, {v1, v2, v3}, {t3_idx}, outputs, {}, state); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index dec86627b0..9941a21d51 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -53,19 +53,28 @@ TEST(SessionStateTest, AddGetKernelTest) { outputs.push_back(&output_arg); onnxruntime::Node& node = graph.AddNode("node_1", "Variable", "node 1.", inputs, outputs); auto status = graph.Resolve(); - EXPECT_TRUE(status.IsOK()); - KernelDef kernel_def; - CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}}; + ASSERT_TRUE(status.IsOK()); + auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); + auto cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo(false)); - OpKernelInfo p_info(node, kernel_def, execution_provider, s.GetConstantInitializedTensors(), + OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider.get(), s.GetConstantInitializedTensors(), s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr()); unique_ptr p_kernel; p_kernel.reset(new TestOpKernel(p_info)); size_t orig_num_outputs = p_kernel->Node().OutputDefs().size(); std::cout << "node_idx: " << node.Index() << std::endl; - s.SetGraphViewer(std::make_unique(graph)); - s.AddKernel(node.Index(), std::move(p_kernel)); + execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider)); + KernelRegistryManager kernel_registry_manager; + status = kernel_registry_manager.RegisterKernels(execution_providers); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + node.SetExecutionProviderType(kCpuExecutionProvider); + std::shared_ptr kernel_registry = std::make_shared(); + kernel_registry->Register(KernelCreateInfo( + std::move(kernel_def), [](const OpKernelInfo& info) -> OpKernel* { return new TestOpKernel(info); })); + kernel_registry_manager.RegisterKernelRegistry(kernel_registry); + status = s.SetGraphAndCreateKernels(graph, kernel_registry_manager); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); auto test_kernel = s.GetKernel(node.Index()); std::cout << "orig: " << orig_num_outputs << " new: " << test_kernel->Node().OutputDefs().size() << std::endl; EXPECT_EQ(orig_num_outputs, test_kernel->Node().OutputDefs().size()); @@ -79,8 +88,7 @@ class TestParam { }; TestParam param_list[] = {{3, true}, {4, true}, {3, false}, {4, false}}; } // namespace -class SessionStateTestP : public testing::TestWithParam { -}; +class SessionStateTestP : public testing::TestWithParam {}; // Test that we separate out constant and non-constant initializers correctly TEST_P(SessionStateTestP, TestInitializerProcessing) { const TestParam& param = GetParam(); @@ -104,8 +112,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { ASSERT_TRUE(status.IsOK()) << status; SessionState session_state(execution_providers, param.enable_mem_pattern, &tp); - SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, - session_state, execution_providers, krm); + SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph, session_state, + execution_providers, krm); GraphPartitioner partitioner(krm, execution_providers); status = partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()); @@ -114,9 +122,6 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { status = session_initializer.CreatePlan(nullptr, nullptr, true); ASSERT_TRUE(status.IsOK()) << status; - status = session_initializer.InitializeAndSave(nullptr); - ASSERT_TRUE(status.IsOK()) << status; - const auto& initialized_tensors = session_state.GetInitializedTensors(); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); @@ -144,7 +149,6 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { } } -INSTANTIATE_TEST_CASE_P(SessionStateTests, SessionStateTestP, - testing::ValuesIn(param_list)); +INSTANTIATE_TEST_CASE_P(SessionStateTests, SessionStateTestP, testing::ValuesIn(param_list)); } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/test_tensor_loader.cc b/onnxruntime/test/framework/test_tensor_loader.cc index 9bdacd595d..2d1dcea7f9 100644 --- a/onnxruntime/test/framework/test_tensor_loader.cc +++ b/onnxruntime/test/framework/test_tensor_loader.cc @@ -135,7 +135,7 @@ TEST(CApiTest, load_float_tensor_with_external_data) { } #if defined(__amd64__) || defined(_M_X64) - +#ifdef NDEBUG TEST(CApiTest, load_huge_tensor_with_external_data) { FILE* fp; std::basic_string filename(ORT_TSTR("tensor_XXXXXX")); @@ -183,5 +183,6 @@ TEST(CApiTest, load_huge_tensor_with_external_data) { } } #endif +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/microbenchmark/model_init.cc b/onnxruntime/test/onnx/microbenchmark/model_init.cc deleted file mode 100644 index ecf48fa655..0000000000 --- a/onnxruntime/test/onnx/microbenchmark/model_init.cc +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include -#include -#include -#include -#include -#ifdef USE_CUDA -#include -#endif -#ifdef USE_MKLDNN -#include -#endif -#include -#include -#include -#include -using namespace google::protobuf::io; - -constexpr const char* model_str = - "ir_version: 4\n" - "graph {\n" - " node {\n" - " input: \"X\"\n" - " input: \"X\"\n" - " output: \"Y\"\n" - " op_type: \"MatMul\"\n" - " }\n" - " name: \"test-model\"\n" - " input {\n" - " name: \"X\"\n" - " type {\n" - " tensor_type {\n" - " elem_type: 1\n" - " shape {\n" - " dim {\n" - " dim_value: 2\n" - " }\n" - " dim {\n" - " dim_value: 2\n" - " }\n" - " }\n" - " }\n" - " }\n" - " }\n" - " output {\n" - " name: \"Y\"\n" - " type {\n" - " tensor_type {\n" - " elem_type: 1\n" - " shape {\n" - " dim {\n" - " dim_value: 2\n" - " }\n" - " dim {\n" - " dim_value: 2\n" - " }\n" - " }\n" - " }\n" - " }\n" - " }\n" - "}\n" - "opset_import {\n" - " version: 8\n" - "}"; - -using namespace onnxruntime; - -#define BM_BREAK_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) state.SkipWithError(_status.ErrorMessage().c_str()); \ - } while (0) - -Status CreateModelFromStr(const char* str, std::unique_ptr* out) { - ONNX_NAMESPACE::ModelProto mp; - if (!google::protobuf::TextFormat::ParseFromString(str, &mp)) throw std::runtime_error("load model failed"); - *out = std::make_unique(mp); - return Status::OK(); -} - -Status CreateExecutionProviders(std::unique_ptr* ret) { - std::unique_ptr execution_providers = std::make_unique(); -#ifdef USE_CUDA - { - CUDAExecutionProviderInfo epi; - ORT_RETURN_IF_ERROR( - execution_providers->Add(onnxruntime::kCudaExecutionProvider, std::make_unique(epi))); - } -#endif -#ifdef USE_MKLDNN - { - MKLDNNExecutionProviderInfo epi; - ORT_RETURN_IF_ERROR(execution_providers->Add(onnxruntime::kMklDnnExecutionProvider, - std::make_unique(epi))); - } -#endif - { - CPUExecutionProviderInfo epi; - ORT_RETURN_IF_ERROR( - execution_providers->Add(onnxruntime::kCpuExecutionProvider, std::make_unique(epi))); - } - *ret = std::move(execution_providers); - return Status::OK(); -} - -Status CreateKernelRegistryManagerFromModel(std::unique_ptr* ret, Model* model, concurrency::ThreadPool& tp) { - std::unique_ptr execution_providers; - ORT_RETURN_IF_ERROR(CreateExecutionProviders(&execution_providers)); - std::unique_ptr kernel_registry_manager = std::make_unique(); - ORT_RETURN_IF_ERROR(kernel_registry_manager->RegisterKernels(*execution_providers)); - SessionState s{*execution_providers, true, &tp}; - s.SetLogger(logging::LoggingManager::DefaultLogger()); - - ORT_RETURN_IF_ERROR(model->MainGraph().Resolve()); - s.SetGraphViewer(std::make_unique(model->MainGraph())); - GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers); - ORT_RETURN_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr())); - *ret = std::move(kernel_registry_manager); - return Status::OK(); -} - -static void SearchKernelRegistry_IMPL(benchmark::State& state, Model* model) { - std::unique_ptr kernel_registry_manager; - concurrency::ThreadPool tp{"test", 1}; - auto st = CreateKernelRegistryManagerFromModel(&kernel_registry_manager, model, tp); - if (!st.IsOK()) throw std::runtime_error("failed"); - for (auto _ : state) { - for (const auto& n : model->MainGraph().Nodes()) { - const KernelCreateInfo* info; - BM_BREAK_IF_ERROR(kernel_registry_manager->SearchKernelRegistry(n, &info)); - if (info == nullptr) state.SkipWithError("Search kernel failed"); - } - } -} - -static void BM_SearchKernelRegistry_SingleNodeModel(benchmark::State& state) { - std::unique_ptr model; - Status st = CreateModelFromStr(model_str, &model); - if (!st.IsOK()) throw std::runtime_error("failed"); - SearchKernelRegistry_IMPL(state, model.get()); -} - -BENCHMARK(BM_SearchKernelRegistry_SingleNodeModel); - -static void BM_SearchKernelRegistry_RealModel_tiny_yolo(benchmark::State& state) { - std::shared_ptr model; - auto st = onnxruntime::Model::Load("../models/opset8/test_tiny_yolov2/model.onnx", model); - SearchKernelRegistry_IMPL(state, model.get()); -} - -BENCHMARK(BM_SearchKernelRegistry_RealModel_tiny_yolo); - -static void BM_SearchKernelRegistry_RealModel_inception_v4(benchmark::State& state) { - std::shared_ptr model; - auto st = onnxruntime::Model::Load("../models/opset9/tf_inception_v4/model.onnx", model); - SearchKernelRegistry_IMPL(state, model.get()); -} - -BENCHMARK(BM_SearchKernelRegistry_RealModel_inception_v4); - -static void BM_PartitionModel_tiny_yolo(benchmark::State& state) { - int fd; - Status status = Env::Default().FileOpenRd("../models/opset8/test_tiny_yolov2/model.onnx", fd); - if (!status.IsOK()) throw std::runtime_error("open test data failed"); - auto raw_input = std::unique_ptr(std::make_unique(fd)); - auto coded_input = std::make_unique(raw_input.get()); - - ONNX_NAMESPACE::ModelProto model_proto; - if (!model_proto.ParseFromCodedStream(coded_input.get())) throw std::runtime_error("open test data failed"); - std::unique_ptr execution_providers; - BM_BREAK_IF_ERROR(CreateExecutionProviders(&execution_providers)); - std::unique_ptr kernel_registry_manager = std::make_unique(); - status = kernel_registry_manager->RegisterKernels(*execution_providers); - if (!status.IsOK()) throw std::runtime_error("RegisterKernels failed"); - concurrency::ThreadPool tp{"test", 1}; - - for (auto _ : state) { - state.PauseTiming(); - std::shared_ptr model = std::make_shared(model_proto); - SessionState s{*execution_providers, true, &tp}; - s.SetLogger(logging::LoggingManager::DefaultLogger()); - BM_BREAK_IF_ERROR(model->MainGraph().Resolve()); - s.SetGraphViewer(std::make_unique(model->MainGraph())); - GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers); - state.ResumeTiming(); - BM_BREAK_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr())); - } -} - -BENCHMARK(BM_PartitionModel_tiny_yolo); - -static void BM_PartitionModel_inception_v4(benchmark::State& state) { - int fd; - Status status = Env::Default().FileOpenRd("../models/opset9/tf_inception_v4/model.onnx", fd); - if (!status.IsOK()) throw std::runtime_error("open test data failed"); - auto raw_input = std::unique_ptr(std::make_unique(fd)); - auto coded_input = std::make_unique(raw_input.get()); - - ONNX_NAMESPACE::ModelProto model_proto; - if (!model_proto.ParseFromCodedStream(coded_input.get())) throw std::runtime_error("open test data failed"); - std::unique_ptr execution_providers; - BM_BREAK_IF_ERROR(CreateExecutionProviders(&execution_providers)); - std::unique_ptr kernel_registry_manager = std::make_unique(); - status = kernel_registry_manager->RegisterKernels(*execution_providers); - if (!status.IsOK()) throw std::runtime_error("RegisterKernels failed"); - concurrency::ThreadPool tp{"test", 1}; - - for (auto _ : state) { - state.PauseTiming(); - std::shared_ptr model = std::make_shared(model_proto); - SessionState s{*execution_providers, true, &tp}; - s.SetLogger(logging::LoggingManager::DefaultLogger()); - BM_BREAK_IF_ERROR(model->MainGraph().Resolve()); - s.SetGraphViewer(std::make_unique(model->MainGraph())); - GraphPartitioner partitioner(*kernel_registry_manager, *execution_providers); - state.ResumeTiming(); - BM_BREAK_IF_ERROR(partitioner.Partition(model->MainGraph(), s.ExportDll(), s.GetMutableFuncMgr())); - } -} - -BENCHMARK(BM_PartitionModel_inception_v4); diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index 133c2873e2..bf52a11579 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -42,14 +42,12 @@ TEST(MemcpyTest, copy1) { Model model(mp); st = model.MainGraph().Resolve(); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); - s.SetGraphViewer(std::make_unique(model.MainGraph())); PutAllNodesOnOneProvider(model.MainGraph(), onnxruntime::kCpuExecutionProvider); SessionStateInitializer session_initializer{true, ORT_TSTR(""), model.MainGraph(), s, execution_providers, kernel_registry_manager}; st = session_initializer.CreatePlan(nullptr, {}, true); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); - st = session_initializer.InitializeAndSave(nullptr); - ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + AllocatorPtr allocator = execution_providers.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault); auto* data_type = DataTypeImpl::GetType();