From 0ad940027c0b93e971663cdf3fbcc83eb10b7b5b Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 1 May 2019 14:22:28 +1000 Subject: [PATCH] Use ConstPointerContainer for Node::ImplicitInputDefs() for better consistency with InputDefs() and OutputDefs(). (#894) --- .../core/common/const_pointer_container.h | 4 ++ include/onnxruntime/core/graph/graph.h | 30 +++++------ .../framework/op_kernel_context_internal.h | 4 +- .../framework/session_state_initializer.cc | 52 +++++++++++-------- .../framework/session_state_initializer.h | 5 +- onnxruntime/core/graph/graph_utils.cc | 4 +- .../optimizer/optimizer_execution_frame.cc | 2 +- onnxruntime/core/session/inference_session.cc | 7 +-- 8 files changed, 60 insertions(+), 48 deletions(-) diff --git a/include/onnxruntime/core/common/const_pointer_container.h b/include/onnxruntime/core/common/const_pointer_container.h index 49f0421146..1d821ba609 100644 --- a/include/onnxruntime/core/common/const_pointer_container.h +++ b/include/onnxruntime/core/common/const_pointer_container.h @@ -64,6 +64,10 @@ class ConstPointerContainer { explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {} size_t size() const noexcept { return data_.size(); } + bool empty() const noexcept { return data_.empty(); } + + ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); } + ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); } ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); } ConstIterator end() const noexcept { return ConstIterator(data_.cend()); } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index b0cc98f47f..fb02f9d935 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -125,6 +125,14 @@ class Node { return common::Status::OK(); } + /** Gets the count of arguments for each of the Node's explicit inputs. */ + const std::vector& InputArgCount() const noexcept { return definitions_.input_arg_count; } + + /** Gets a modifiable count of arguments for each of the Node's explicit inputs. + @todo This should be removed in favor of a method that updates the input args and the count. + Currently these operations are separate which is not a good setup. */ + std::vector& MutableInputArgsCount() { return definitions_.input_arg_count; } + /** Gets the Node's input definitions. @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */ const ConstPointerContainer> InputDefs() const noexcept { @@ -136,24 +144,11 @@ class Node { return definitions_.input_defs; } - /** Gets a modifiable collection of the Node's output definitions. */ - std::vector& MutableOutputDefs() noexcept { - return definitions_.output_defs; - } - - /** Gets the count of arguments for each of the Node's explicit inputs. */ - const std::vector& InputArgCount() const noexcept { return definitions_.input_arg_count; } - - /** Gets a modifiable count of arguments for each of the Node's explicit inputs. - @todo This should be removed in favor of a method that updates the input args and the count. - Currently these operations are separate which is not a good setup. */ - std::vector& MutableInputArgsCount() { return definitions_.input_arg_count; } - /** Gets the implicit inputs to this Node. If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that subgraph. e.g. If and Loop operators.*/ - const std::vector& ImplicitInputDefs() const noexcept { - return definitions_.implicit_input_defs; + const ConstPointerContainer> ImplicitInputDefs() const noexcept { + return ConstPointerContainer>(definitions_.implicit_input_defs); } /** Gets a modifiable collection of the Node's implicit input definitions. */ @@ -167,6 +162,11 @@ class Node { return ConstPointerContainer>(definitions_.output_defs); } + /** Gets a modifiable collection of the Node's output definitions. */ + std::vector& MutableOutputDefs() noexcept { + return definitions_.output_defs; + } + /** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, and NodeArg::Name second. */ struct EdgeEndCompare { bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const { diff --git a/onnxruntime/core/framework/op_kernel_context_internal.h b/onnxruntime/core/framework/op_kernel_context_internal.h index c65fbd0e18..c8f6901de1 100644 --- a/onnxruntime/core/framework/op_kernel_context_internal.h +++ b/onnxruntime/core/framework/op_kernel_context_internal.h @@ -19,7 +19,7 @@ class OpKernelContextInternal : public OpKernelContext { IExecutionFrame& frame, const OpKernel& kernel, const logging::Logger& logger, - const std::vector& implicit_inputs, + const ConstPointerContainer> implicit_inputs, const bool& terminate_flag) : OpKernelContext(&frame, &kernel, logger), session_state_{session_state}, @@ -61,7 +61,7 @@ class OpKernelContextInternal : public OpKernelContext { private: const SessionState& session_state_; - const std::vector& implicit_inputs_; + const ConstPointerContainer> implicit_inputs_; const bool& terminate_flag_; }; diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index ab72323c49..28550605a5 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ b/onnxruntime/core/framework/session_state_initializer.cc @@ -43,10 +43,11 @@ static common::Status SaveKernels(const ExecutionProviders& execution_providers, const KernelRegistryManager& custom_registry_manager, const logging::Logger& logger); -static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph, - const KernelRegistryManager& custom_registry_manager, - SessionState& session_state, - const std::vector* implicit_inputs); +static common::Status SaveInputOutputNamesToNodeMapping( + const onnxruntime::Graph& graph, + const KernelRegistryManager& custom_registry_manager, + SessionState& session_state, + const ConstPointerContainer>* implicit_inputs); SessionStateInitializer::SessionStateInitializer(const std::basic_string& graph_loc, onnxruntime::Graph& graph, SessionState& session_state, @@ -59,9 +60,10 @@ SessionStateInitializer::SessionStateInitializer(const std::basic_string& outer_scope_node_args, - bool enable_sequential_execution) { +common::Status SessionStateInitializer::CreatePlan( + const Node* parent_node, + const ConstPointerContainer>* outer_scope_node_args, + bool enable_sequential_execution) { auto graph_viewer = std::make_unique(graph_); // populate the SessionState MLValueNameIdxMap @@ -70,13 +72,15 @@ common::Status SessionStateInitializer::CreatePlan(const Node* parent_node, // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. std::vector valid_outer_scope_node_args; - std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(), - [&mlvalue_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) { - int idx; - if (mlvalue_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) { - valid_outer_scope_node_args.push_back(node_arg); - }; - }); + if (outer_scope_node_args) { + std::for_each(outer_scope_node_args->cbegin(), outer_scope_node_args->cend(), + [&mlvalue_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) { + int idx; + if (mlvalue_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) { + valid_outer_scope_node_args.push_back(node_arg); + }; + }); + } std::unique_ptr exec_plan; @@ -103,7 +107,8 @@ common::Status SessionStateInitializer::CreatePlan(const Node* parent_node, return Status::OK(); } -common::Status SessionStateInitializer::InitializeAndSave(const std::vector* implicit_inputs) { +common::Status SessionStateInitializer::InitializeAndSave( + const ConstPointerContainer>* implicit_inputs) { const auto* exec_plan_ptr = session_state_.GetExecutionPlan(); ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first."); @@ -188,7 +193,8 @@ common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, - const ExecutionProviders& exec_providers, MLValue& mlvalue, OrtCallback& deleter) { + const ExecutionProviders& exec_providers, MLValue& mlvalue, + OrtCallback& deleter) { const OrtAllocatorInfo& alloc_info = m.GetAllocInfo(); if (strcmp(alloc_info.name, CPU) == 0 || alloc_info.mem_type == OrtMemTypeCPUOutput) { // deserialize directly to CPU tensor @@ -212,8 +218,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st size_t cpu_tensor_length; ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length)); if (m.GetLen() < cpu_tensor_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ", cpu_tensor_length, - ", Got ", m.GetLen()); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ", + cpu_tensor_length, ", Got ", m.GetLen()); } OrtAllocatorInfo info(CPU, OrtDeviceAllocator, 0, OrtMemTypeDefault); std::unique_ptr data(new char[cpu_tensor_length]); @@ -411,19 +417,19 @@ common::Status SaveKernels(const ExecutionProviders& execution_providers, return Status::OK(); } -template // T is const NodeArg or NodeArg +template // T is container of const NodeArg* or NodeArg* static bool IsArgNameInInputsOutputs(const std::string& name, - const std::vector& graph_args) { - auto it = std::find_if(std::begin(graph_args), std::end(graph_args), [&name](const onnxruntime::NodeArg* arg) { + const T& graph_args) { + auto it = std::find_if(graph_args.cbegin(), graph_args.cend(), [&name](const onnxruntime::NodeArg* arg) { return arg->Name() == name; }); - return it != graph_args.end(); + return it != graph_args.cend(); } common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph, const KernelRegistryManager& custom_registry_manager, SessionState& session_state, - const std::vector* implicit_inputs) { + const ConstPointerContainer>* implicit_inputs) { auto& graph_inputs = graph.GetInputsIncludingInitializers(); auto& graph_outputs = graph.GetOutputs(); diff --git a/onnxruntime/core/framework/session_state_initializer.h b/onnxruntime/core/framework/session_state_initializer.h index 47516c3b75..8ef554c9f2 100644 --- a/onnxruntime/core/framework/session_state_initializer.h +++ b/onnxruntime/core/framework/session_state_initializer.h @@ -4,6 +4,7 @@ #pragma once #include +#include "core/common/const_pointer_container.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/path_lib.h" @@ -35,12 +36,12 @@ class SessionStateInitializer { // First perform any transformations and create the execution plan common::Status CreatePlan(const Node* parent_node, - const std::vector& outer_scope_node_args, + const ConstPointerContainer>* outer_scope_node_args, bool enable_sequential_execution); // initialize tensors, and save. save kernels and input/output node mappings // \param implicit_inputs could be NULL - common::Status InitializeAndSave(const std::vector* implicit_inputs); + common::Status InitializeAndSave(const ConstPointerContainer>* implicit_inputs); private: const std::basic_string& graph_loc_; diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 269e981dd9..2a01bdacb4 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -69,7 +69,7 @@ static bool CanUpdateImplicitInputNameInSubgraph(Node& node, for (auto& subgraph_node : attr_subgraph_pair.second->Nodes()) { // recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels of nested // subgraphs, and at least one level lower uses removed_output_name as an implicit input - const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs(); + const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs(); if (!subgraph_node_implicit_inputs.empty()) { auto subgraph_node_also_consumes_nodearg_as_implicit_input = std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(), @@ -99,7 +99,7 @@ static void UpdateImplicitInputNameInSubgraph(Node& node, // recurse if this node also consumes removed_output_name as an implicit input // (i.e. there are multiple levels of nested subgraphs, and at least one level lower uses // removed_output_name as an implicit input - const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs(); + const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs(); if (!subgraph_node_implicit_inputs.empty()) { auto subgraph_node_also_consumes_nodearg_as_implicit_input = std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(), diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index 4350cb7073..36b1ef803f 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -128,4 +128,4 @@ Status OptimizerExecutionFrame::CreateNodeOutputMLValueImpl(MLValue& mlvalue, in return Status::OK(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cb80691436..16b69960b9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -380,10 +380,11 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio SessionStateInitializer initializer{model_location_, subgraph, *subgraph_session_state, execution_providers_, kernel_registry_manager_}; - ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, node.ImplicitInputDefs(), + const auto implicit_inputs = node.ImplicitInputDefs(); + ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs, session_options_.enable_sequential_execution)); - ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&node.ImplicitInputDefs())); + ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&implicit_inputs)); // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), // &*subgraph_info.session_state); @@ -451,7 +452,7 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR(graph.Resolve()); - ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, {}, session_options_.enable_sequential_execution)); + ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution)); ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr)); // handle any subgraphs