Use ConstPointerContainer for Node::ImplicitInputDefs() for better consistency with InputDefs() and OutputDefs(). (#894)

This commit is contained in:
Scott McKay 2019-05-01 14:22:28 +10:00 committed by GitHub
parent df513c7fe6
commit 0ad940027c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 60 additions and 48 deletions

View file

@ -64,6 +64,10 @@ class ConstPointerContainer {
explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {}
size_t size() const noexcept { return data_.size(); }
bool empty() const noexcept { return data_.empty(); }
ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); }
ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); }
ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); }
ConstIterator end() const noexcept { return ConstIterator(data_.cend()); }

View file

@ -125,6 +125,14 @@ class Node {
return common::Status::OK();
}
/** Gets the count of arguments for each of the Node's explicit inputs. */
const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
/** Gets a modifiable count of arguments for each of the Node's explicit inputs.
@todo This should be removed in favor of a method that updates the input args and the count.
Currently these operations are separate which is not a good setup. */
std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
/** Gets the Node's input definitions.
@remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
const ConstPointerContainer<std::vector<NodeArg*>> InputDefs() const noexcept {
@ -136,24 +144,11 @@ class Node {
return definitions_.input_defs;
}
/** Gets a modifiable collection of the Node's output definitions. */
std::vector<NodeArg*>& MutableOutputDefs() noexcept {
return definitions_.output_defs;
}
/** Gets the count of arguments for each of the Node's explicit inputs. */
const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
/** Gets a modifiable count of arguments for each of the Node's explicit inputs.
@todo This should be removed in favor of a method that updates the input args and the count.
Currently these operations are separate which is not a good setup. */
std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
/** Gets the implicit inputs to this Node.
If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that
subgraph. e.g. If and Loop operators.*/
const std::vector<NodeArg*>& ImplicitInputDefs() const noexcept {
return definitions_.implicit_input_defs;
const ConstPointerContainer<std::vector<NodeArg*>> ImplicitInputDefs() const noexcept {
return ConstPointerContainer<std::vector<NodeArg*>>(definitions_.implicit_input_defs);
}
/** Gets a modifiable collection of the Node's implicit input definitions. */
@ -167,6 +162,11 @@ class Node {
return ConstPointerContainer<std::vector<NodeArg*>>(definitions_.output_defs);
}
/** Gets a modifiable collection of the Node's output definitions. */
std::vector<NodeArg*>& MutableOutputDefs() noexcept {
return definitions_.output_defs;
}
/** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, and NodeArg::Name second. */
struct EdgeEndCompare {
bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const {

View file

@ -19,7 +19,7 @@ class OpKernelContextInternal : public OpKernelContext {
IExecutionFrame& frame,
const OpKernel& kernel,
const logging::Logger& logger,
const std::vector<NodeArg*>& implicit_inputs,
const ConstPointerContainer<std::vector<NodeArg*>> implicit_inputs,
const bool& terminate_flag)
: OpKernelContext(&frame, &kernel, logger),
session_state_{session_state},
@ -61,7 +61,7 @@ class OpKernelContextInternal : public OpKernelContext {
private:
const SessionState& session_state_;
const std::vector<NodeArg*>& implicit_inputs_;
const ConstPointerContainer<std::vector<NodeArg*>> implicit_inputs_;
const bool& terminate_flag_;
};

View file

@ -43,10 +43,11 @@ static common::Status SaveKernels(const ExecutionProviders& execution_providers,
const KernelRegistryManager& custom_registry_manager,
const logging::Logger& logger);
static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
const KernelRegistryManager& custom_registry_manager,
SessionState& session_state,
const std::vector<NodeArg*>* implicit_inputs);
static common::Status SaveInputOutputNamesToNodeMapping(
const onnxruntime::Graph& graph,
const KernelRegistryManager& custom_registry_manager,
SessionState& session_state,
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs);
SessionStateInitializer::SessionStateInitializer(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
onnxruntime::Graph& graph, SessionState& session_state,
@ -59,9 +60,10 @@ SessionStateInitializer::SessionStateInitializer(const std::basic_string<PATH_CH
kernel_registry_manager_{kernel_registry_manager},
logger_{session_state.Logger()} {}
common::Status SessionStateInitializer::CreatePlan(const Node* parent_node,
const std::vector<NodeArg*>& outer_scope_node_args,
bool enable_sequential_execution) {
common::Status SessionStateInitializer::CreatePlan(
const Node* parent_node,
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
bool enable_sequential_execution) {
auto graph_viewer = std::make_unique<onnxruntime::GraphViewer>(graph_);
// populate the SessionState MLValueNameIdxMap
@ -70,13 +72,15 @@ common::Status SessionStateInitializer::CreatePlan(const Node* parent_node,
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
std::vector<const NodeArg*> valid_outer_scope_node_args;
std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(),
[&mlvalue_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) {
int idx;
if (mlvalue_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) {
valid_outer_scope_node_args.push_back(node_arg);
};
});
if (outer_scope_node_args) {
std::for_each(outer_scope_node_args->cbegin(), outer_scope_node_args->cend(),
[&mlvalue_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) {
int idx;
if (mlvalue_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) {
valid_outer_scope_node_args.push_back(node_arg);
};
});
}
std::unique_ptr<SequentialExecutionPlan> exec_plan;
@ -103,7 +107,8 @@ common::Status SessionStateInitializer::CreatePlan(const Node* parent_node,
return Status::OK();
}
common::Status SessionStateInitializer::InitializeAndSave(const std::vector<NodeArg*>* implicit_inputs) {
common::Status SessionStateInitializer::InitializeAndSave(
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs) {
const auto* exec_plan_ptr = session_state_.GetExecutionPlan();
ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first.");
@ -188,7 +193,8 @@ common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer,
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m,
const ExecutionProviders& exec_providers, MLValue& mlvalue, OrtCallback& deleter) {
const ExecutionProviders& exec_providers, MLValue& mlvalue,
OrtCallback& deleter) {
const OrtAllocatorInfo& alloc_info = m.GetAllocInfo();
if (strcmp(alloc_info.name, CPU) == 0 || alloc_info.mem_type == OrtMemTypeCPUOutput) {
// deserialize directly to CPU tensor
@ -212,8 +218,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
size_t cpu_tensor_length;
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length));
if (m.GetLen() < cpu_tensor_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ", cpu_tensor_length,
", Got ", m.GetLen());
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ",
cpu_tensor_length, ", Got ", m.GetLen());
}
OrtAllocatorInfo info(CPU, OrtDeviceAllocator, 0, OrtMemTypeDefault);
std::unique_ptr<char[]> data(new char[cpu_tensor_length]);
@ -411,19 +417,19 @@ common::Status SaveKernels(const ExecutionProviders& execution_providers,
return Status::OK();
}
template <typename T> // T is const NodeArg or NodeArg
template <typename T> // T is container of const NodeArg* or NodeArg*
static bool IsArgNameInInputsOutputs(const std::string& name,
const std::vector<T*>& graph_args) {
auto it = std::find_if(std::begin(graph_args), std::end(graph_args), [&name](const onnxruntime::NodeArg* arg) {
const T& graph_args) {
auto it = std::find_if(graph_args.cbegin(), graph_args.cend(), [&name](const onnxruntime::NodeArg* arg) {
return arg->Name() == name;
});
return it != graph_args.end();
return it != graph_args.cend();
}
common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
const KernelRegistryManager& custom_registry_manager,
SessionState& session_state,
const std::vector<NodeArg*>* implicit_inputs) {
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs) {
auto& graph_inputs = graph.GetInputsIncludingInitializers();
auto& graph_outputs = graph.GetOutputs();

View file

@ -4,6 +4,7 @@
#pragma once
#include <map>
#include "core/common/const_pointer_container.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/path_lib.h"
@ -35,12 +36,12 @@ class SessionStateInitializer {
// First perform any transformations and create the execution plan
common::Status CreatePlan(const Node* parent_node,
const std::vector<NodeArg*>& outer_scope_node_args,
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
bool enable_sequential_execution);
// initialize tensors, and save. save kernels and input/output node mappings
// \param implicit_inputs could be NULL
common::Status InitializeAndSave(const std::vector<NodeArg*>* implicit_inputs);
common::Status InitializeAndSave(const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs);
private:
const std::basic_string<PATH_CHAR_TYPE>& graph_loc_;

View file

@ -69,7 +69,7 @@ static bool CanUpdateImplicitInputNameInSubgraph(Node& node,
for (auto& subgraph_node : attr_subgraph_pair.second->Nodes()) {
// recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels of nested
// subgraphs, and at least one level lower uses removed_output_name as an implicit input)
const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
if (!subgraph_node_implicit_inputs.empty()) {
auto subgraph_node_also_consumes_nodearg_as_implicit_input =
std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(),
@ -99,7 +99,7 @@ static void UpdateImplicitInputNameInSubgraph(Node& node,
// recurse if this node also consumes removed_output_name as an implicit input
// (i.e. there are multiple levels of nested subgraphs, and at least one level lower uses
// removed_output_name as an implicit input)
const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
if (!subgraph_node_implicit_inputs.empty()) {
auto subgraph_node_also_consumes_nodearg_as_implicit_input =
std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(),

View file

@ -128,4 +128,4 @@ Status OptimizerExecutionFrame::CreateNodeOutputMLValueImpl(MLValue& mlvalue, in
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -380,10 +380,11 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
SessionStateInitializer initializer{model_location_, subgraph, *subgraph_session_state, execution_providers_,
kernel_registry_manager_};
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, node.ImplicitInputDefs(),
const auto implicit_inputs = node.ImplicitInputDefs();
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs,
session_options_.enable_sequential_execution));
ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&node.ImplicitInputDefs()));
ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&implicit_inputs));
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
// &*subgraph_info.session_state);
@ -451,7 +452,7 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, {}, session_options_.enable_sequential_execution));
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution));
ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr));
// handle any subgraphs