mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Use ConstPointerContainer for Node::ImplicitInputDefs() for better consistency with InputDefs() and OutputDefs(). (#894)
This commit is contained in:
parent
df513c7fe6
commit
0ad940027c
8 changed files with 60 additions and 48 deletions
|
|
@ -64,6 +64,10 @@ class ConstPointerContainer {
|
|||
explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {}
|
||||
|
||||
size_t size() const noexcept { return data_.size(); }
|
||||
bool empty() const noexcept { return data_.empty(); }
|
||||
|
||||
ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); }
|
||||
ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); }
|
||||
|
||||
ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); }
|
||||
ConstIterator end() const noexcept { return ConstIterator(data_.cend()); }
|
||||
|
|
|
|||
|
|
@ -125,6 +125,14 @@ class Node {
|
|||
return common::Status::OK();
|
||||
}
|
||||
|
||||
/** Gets the count of arguments for each of the Node's explicit inputs. */
|
||||
const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
|
||||
|
||||
/** Gets a modifiable count of arguments for each of the Node's explicit inputs.
|
||||
@todo This should be removed in favor of a method that updates the input args and the count.
|
||||
Currently these operations are separate which is not a good setup. */
|
||||
std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
|
||||
|
||||
/** Gets the Node's input definitions.
|
||||
@remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
|
||||
const ConstPointerContainer<std::vector<NodeArg*>> InputDefs() const noexcept {
|
||||
|
|
@ -136,24 +144,11 @@ class Node {
|
|||
return definitions_.input_defs;
|
||||
}
|
||||
|
||||
/** Gets a modifiable collection of the Node's output definitions. */
|
||||
std::vector<NodeArg*>& MutableOutputDefs() noexcept {
|
||||
return definitions_.output_defs;
|
||||
}
|
||||
|
||||
/** Gets the count of arguments for each of the Node's explicit inputs. */
|
||||
const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
|
||||
|
||||
/** Gets a modifiable count of arguments for each of the Node's explicit inputs.
|
||||
@todo This should be removed in favor of a method that updates the input args and the count.
|
||||
Currently these operations are separate which is not a good setup. */
|
||||
std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
|
||||
|
||||
/** Gets the implicit inputs to this Node.
|
||||
If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that
|
||||
subgraph. e.g. If and Loop operators.*/
|
||||
const std::vector<NodeArg*>& ImplicitInputDefs() const noexcept {
|
||||
return definitions_.implicit_input_defs;
|
||||
const ConstPointerContainer<std::vector<NodeArg*>> ImplicitInputDefs() const noexcept {
|
||||
return ConstPointerContainer<std::vector<NodeArg*>>(definitions_.implicit_input_defs);
|
||||
}
|
||||
|
||||
/** Gets a modifiable collection of the Node's implicit input definitions. */
|
||||
|
|
@ -167,6 +162,11 @@ class Node {
|
|||
return ConstPointerContainer<std::vector<NodeArg*>>(definitions_.output_defs);
|
||||
}
|
||||
|
||||
/** Gets a modifiable collection of the Node's output definitions. */
|
||||
std::vector<NodeArg*>& MutableOutputDefs() noexcept {
|
||||
return definitions_.output_defs;
|
||||
}
|
||||
|
||||
/** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, and NodeArg::Name second. */
|
||||
struct EdgeEndCompare {
|
||||
bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class OpKernelContextInternal : public OpKernelContext {
|
|||
IExecutionFrame& frame,
|
||||
const OpKernel& kernel,
|
||||
const logging::Logger& logger,
|
||||
const std::vector<NodeArg*>& implicit_inputs,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>> implicit_inputs,
|
||||
const bool& terminate_flag)
|
||||
: OpKernelContext(&frame, &kernel, logger),
|
||||
session_state_{session_state},
|
||||
|
|
@ -61,7 +61,7 @@ class OpKernelContextInternal : public OpKernelContext {
|
|||
|
||||
private:
|
||||
const SessionState& session_state_;
|
||||
const std::vector<NodeArg*>& implicit_inputs_;
|
||||
const ConstPointerContainer<std::vector<NodeArg*>> implicit_inputs_;
|
||||
const bool& terminate_flag_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -43,10 +43,11 @@ static common::Status SaveKernels(const ExecutionProviders& execution_providers,
|
|||
const KernelRegistryManager& custom_registry_manager,
|
||||
const logging::Logger& logger);
|
||||
|
||||
static common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
SessionState& session_state,
|
||||
const std::vector<NodeArg*>* implicit_inputs);
|
||||
static common::Status SaveInputOutputNamesToNodeMapping(
|
||||
const onnxruntime::Graph& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
SessionState& session_state,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs);
|
||||
|
||||
SessionStateInitializer::SessionStateInitializer(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
onnxruntime::Graph& graph, SessionState& session_state,
|
||||
|
|
@ -59,9 +60,10 @@ SessionStateInitializer::SessionStateInitializer(const std::basic_string<PATH_CH
|
|||
kernel_registry_manager_{kernel_registry_manager},
|
||||
logger_{session_state.Logger()} {}
|
||||
|
||||
common::Status SessionStateInitializer::CreatePlan(const Node* parent_node,
|
||||
const std::vector<NodeArg*>& outer_scope_node_args,
|
||||
bool enable_sequential_execution) {
|
||||
common::Status SessionStateInitializer::CreatePlan(
|
||||
const Node* parent_node,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
|
||||
bool enable_sequential_execution) {
|
||||
auto graph_viewer = std::make_unique<onnxruntime::GraphViewer>(graph_);
|
||||
|
||||
// populate the SessionState MLValueNameIdxMap
|
||||
|
|
@ -70,13 +72,15 @@ common::Status SessionStateInitializer::CreatePlan(const Node* parent_node,
|
|||
|
||||
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
|
||||
std::vector<const NodeArg*> valid_outer_scope_node_args;
|
||||
std::for_each(outer_scope_node_args.cbegin(), outer_scope_node_args.cend(),
|
||||
[&mlvalue_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) {
|
||||
int idx;
|
||||
if (mlvalue_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) {
|
||||
valid_outer_scope_node_args.push_back(node_arg);
|
||||
};
|
||||
});
|
||||
if (outer_scope_node_args) {
|
||||
std::for_each(outer_scope_node_args->cbegin(), outer_scope_node_args->cend(),
|
||||
[&mlvalue_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) {
|
||||
int idx;
|
||||
if (mlvalue_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) {
|
||||
valid_outer_scope_node_args.push_back(node_arg);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<SequentialExecutionPlan> exec_plan;
|
||||
|
||||
|
|
@ -103,7 +107,8 @@ common::Status SessionStateInitializer::CreatePlan(const Node* parent_node,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status SessionStateInitializer::InitializeAndSave(const std::vector<NodeArg*>* implicit_inputs) {
|
||||
common::Status SessionStateInitializer::InitializeAndSave(
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs) {
|
||||
const auto* exec_plan_ptr = session_state_.GetExecutionPlan();
|
||||
ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first.");
|
||||
|
||||
|
|
@ -188,7 +193,8 @@ common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer,
|
|||
|
||||
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
|
||||
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m,
|
||||
const ExecutionProviders& exec_providers, MLValue& mlvalue, OrtCallback& deleter) {
|
||||
const ExecutionProviders& exec_providers, MLValue& mlvalue,
|
||||
OrtCallback& deleter) {
|
||||
const OrtAllocatorInfo& alloc_info = m.GetAllocInfo();
|
||||
if (strcmp(alloc_info.name, CPU) == 0 || alloc_info.mem_type == OrtMemTypeCPUOutput) {
|
||||
// deserialize directly to CPU tensor
|
||||
|
|
@ -212,8 +218,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
|
|||
size_t cpu_tensor_length;
|
||||
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length));
|
||||
if (m.GetLen() < cpu_tensor_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ", cpu_tensor_length,
|
||||
", Got ", m.GetLen());
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ",
|
||||
cpu_tensor_length, ", Got ", m.GetLen());
|
||||
}
|
||||
OrtAllocatorInfo info(CPU, OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
std::unique_ptr<char[]> data(new char[cpu_tensor_length]);
|
||||
|
|
@ -411,19 +417,19 @@ common::Status SaveKernels(const ExecutionProviders& execution_providers,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T> // T is const NodeArg or NodeArg
|
||||
template <typename T> // T is container of const NodeArg* or NodeArg*
|
||||
static bool IsArgNameInInputsOutputs(const std::string& name,
|
||||
const std::vector<T*>& graph_args) {
|
||||
auto it = std::find_if(std::begin(graph_args), std::end(graph_args), [&name](const onnxruntime::NodeArg* arg) {
|
||||
const T& graph_args) {
|
||||
auto it = std::find_if(graph_args.cbegin(), graph_args.cend(), [&name](const onnxruntime::NodeArg* arg) {
|
||||
return arg->Name() == name;
|
||||
});
|
||||
return it != graph_args.end();
|
||||
return it != graph_args.cend();
|
||||
}
|
||||
|
||||
common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
|
||||
const KernelRegistryManager& custom_registry_manager,
|
||||
SessionState& session_state,
|
||||
const std::vector<NodeArg*>* implicit_inputs) {
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs) {
|
||||
auto& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
auto& graph_outputs = graph.GetOutputs();
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
#include <map>
|
||||
|
||||
#include "core/common/const_pointer_container.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/path_lib.h"
|
||||
|
|
@ -35,12 +36,12 @@ class SessionStateInitializer {
|
|||
|
||||
// First perform any transformations and create the execution plan
|
||||
common::Status CreatePlan(const Node* parent_node,
|
||||
const std::vector<NodeArg*>& outer_scope_node_args,
|
||||
const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
|
||||
bool enable_sequential_execution);
|
||||
|
||||
// initialize tensors, and save. save kernels and input/output node mappings
|
||||
// \param implicit_inputs could be NULL
|
||||
common::Status InitializeAndSave(const std::vector<NodeArg*>* implicit_inputs);
|
||||
common::Status InitializeAndSave(const ConstPointerContainer<std::vector<NodeArg*>>* implicit_inputs);
|
||||
|
||||
private:
|
||||
const std::basic_string<PATH_CHAR_TYPE>& graph_loc_;
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ static bool CanUpdateImplicitInputNameInSubgraph(Node& node,
|
|||
for (auto& subgraph_node : attr_subgraph_pair.second->Nodes()) {
|
||||
// recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels of nested
|
||||
// subgraphs, and at least one level lower uses removed_output_name as an implicit input
|
||||
const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
|
||||
const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
|
||||
if (!subgraph_node_implicit_inputs.empty()) {
|
||||
auto subgraph_node_also_consumes_nodearg_as_implicit_input =
|
||||
std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(),
|
||||
|
|
@ -99,7 +99,7 @@ static void UpdateImplicitInputNameInSubgraph(Node& node,
|
|||
// recurse if this node also consumes removed_output_name as an implicit input
|
||||
// (i.e. there are multiple levels of nested subgraphs, and at least one level lower uses
|
||||
// removed_output_name as an implicit input
|
||||
const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
|
||||
const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
|
||||
if (!subgraph_node_implicit_inputs.empty()) {
|
||||
auto subgraph_node_also_consumes_nodearg_as_implicit_input =
|
||||
std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(),
|
||||
|
|
|
|||
|
|
@ -128,4 +128,4 @@ Status OptimizerExecutionFrame::CreateNodeOutputMLValueImpl(MLValue& mlvalue, in
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -380,10 +380,11 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
|
|||
SessionStateInitializer initializer{model_location_, subgraph, *subgraph_session_state, execution_providers_,
|
||||
kernel_registry_manager_};
|
||||
|
||||
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, node.ImplicitInputDefs(),
|
||||
const auto implicit_inputs = node.ImplicitInputDefs();
|
||||
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs,
|
||||
session_options_.enable_sequential_execution));
|
||||
|
||||
ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&node.ImplicitInputDefs()));
|
||||
ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&implicit_inputs));
|
||||
|
||||
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
|
||||
// &*subgraph_info.session_state);
|
||||
|
|
@ -451,7 +452,7 @@ common::Status InferenceSession::Initialize() {
|
|||
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
|
||||
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, {}, session_options_.enable_sequential_execution));
|
||||
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.enable_sequential_execution));
|
||||
ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr));
|
||||
|
||||
// handle any subgraphs
|
||||
|
|
|
|||
Loading…
Reference in a new issue