Separate out constant node index information from ExecutionFrame (#410)

* Separate out the NodeArg index information from ExecutionFrame so it is only calculated once.

* Skip copy to/from device if only CPU execution provider is registered.
Cleanups.

* Address PR comments.
Clean up a few areas.

* Fix Linux build error
This commit is contained in:
Scott McKay 2019-02-01 10:55:49 +10:00 committed by GitHub
parent fb7be27096
commit efb72540be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 238 additions and 233 deletions

View file

@ -55,7 +55,7 @@ class Node {
@param node The source node if this is an input edge to the current node,
or the destination node if this is an output edge from the current node.
@param src_arg_index The node arg index of source node of the edge.
@param dst_arg_index The node arg index of destination node of the edge.
@param dst_arg_index The node arg index of destination node of the edge.
*/
EdgeEnd(const Node& node, int src_arg_index, int dst_arg_index) noexcept;
@ -68,11 +68,11 @@ class Node {
const Node& GetNode() const noexcept;
/** Gets the source arg index.
@returns the source arg index of <*this> edge.*/
@returns the source arg index of <*this> edge.*/
int GetSrcArgIndex() const;
/** Gets the destination arg index.
@returns the destination arg index of <*this> edge.*/
@returns the destination arg index of <*this> edge.*/
int GetDstArgIndex() const;
private:
@ -283,8 +283,12 @@ class Node {
void ToProto(ONNX_NAMESPACE::NodeProto& proto) const;
/** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
If the NodeArg is an explicit or implicit input, is_input will be true when func is called. */
void ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func) const;
If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
@param include_missing_optional_defs Include NodeArgs that are optional and were not provided
i.e. NodeArg::Exists() == false.
*/
void ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
bool include_missing_optional_defs = false) const;
/** Replaces any matching definitions in the Node's explicit inputs or explicit outputs.
@param replacements Map of current NodeArg to replacement NodeArg.

View file

@ -7,6 +7,7 @@
#include "core/framework/mem_pattern_planner.h"
#include "core/framework/ml_value_patterns_planner.h"
#include "core/framework/node_index_info.h"
#include "core/framework/op_kernel.h"
#include "core/framework/session_state.h"
#include "core/framework/utils.h"
@ -20,12 +21,11 @@ ExecutionFrame::ExecutionFrame(const std::unordered_map<std::string, MLValue>& f
const std::vector<MLValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state)
: session_state_(session_state),
: node_index_info_(session_state.GetNodeIndexInfo()),
session_state_(session_state),
mem_patterns_(nullptr),
planner_(nullptr) {
auto* graph = session_state.GetGraphViewer();
ORT_ENFORCE(graph);
Init(*graph, feeds, output_names, fetches, fetch_allocators);
Init(feeds, output_names, fetches, fetch_allocators);
// If the session enable memory pattern optimization
// and we have execution plan generated, try to setup
@ -79,7 +79,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(int mlvalue_inde
const OrtAllocatorInfo& location,
const TensorShape& shape,
bool create_fence) {
if (mlvalue_index < 0)
if (mlvalue_index == NodeIndexInfo::kInvalidEntry)
return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs");
auto p_mlvalue = &all_values_[mlvalue_index];
@ -166,15 +166,6 @@ void ExecutionFrame::TraceAllocate(int mlvalue_idx, size_t size) {
}
}
Status ExecutionFrame::AllocateTensorWithSelfOwnBuffer(const int index,
const DataTypeImpl* element_type,
const OrtAllocatorInfo& location,
const TensorShape& shape,
bool create_fence) {
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < node_values_.size());
return AllocateMLValueTensorSelfOwnBufferHelper(node_values_[index], element_type, location, shape, create_fence);
}
Status ExecutionFrame::AllocateMLValueTensorPreAllocateBuffer(int mlvalue_index_to_allocate,
int mlvalue_index_reuse,
const DataTypeImpl* element_type,
@ -221,26 +212,6 @@ Status ExecutionFrame::AllocateTensorWithPreAllocateBufferHelper(MLValue* p_mlva
return Status::OK();
}
Status ExecutionFrame::AllocateTensorWithPreAllocateBuffer(const int offset,
void* pBuffer,
const DataTypeImpl* element_type,
const OrtAllocatorInfo& location,
const TensorShape& shape) {
ORT_ENFORCE(offset >= 0 && offset < node_values_.size());
if (node_values_[offset] < 0)
return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs");
auto value = &all_values_[node_values_[offset]];
return AllocateTensorWithPreAllocateBufferHelper(value, pBuffer, element_type, location, shape);
}
void ExecutionFrame::Release(const int offset) {
ORT_ENFORCE(offset >= 0 && offset < node_offsets_.size());
if (node_values_[offset] >= 0 && node_values_[offset] < all_values_.size()) {
all_values_[node_values_[offset]] = MLValue();
TraceFree(node_values_[offset]);
}
}
Status AllocateTraditionalMLValue(MLValue* p_mlvalue,
const NonTensorTypeBase* type,
const MLValueAllocationParameters& parameters) {
@ -257,7 +228,7 @@ Status AllocateTraditionalMLValue(MLValue* p_mlvalue,
// This method is not thread safe!
Status ExecutionFrame::AllocateAsPerAllocationPlan(int mlvalue_index,
const MLValueAllocationParameters& parameters) {
if (mlvalue_index < 0 || mlvalue_index >= all_values_.size())
if (mlvalue_index == NodeIndexInfo::kInvalidEntry || mlvalue_index >= all_values_.size())
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
"Tried to allocated with invalid mlvalue index: " + std::to_string(mlvalue_index));
@ -319,44 +290,23 @@ Status ExecutionFrame::AllocateAsPerAllocationPlan(int mlvalue_index,
return Status::OK();
}
void ExecutionFrame::Init(const onnxruntime::GraphViewer& graph,
const std::unordered_map<std::string, MLValue>& feeds,
void ExecutionFrame::Init(const std::unordered_map<std::string, MLValue>& feeds,
const std::vector<std::string>& output_names,
const std::vector<MLValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators) {
// 1. resize the node_offsets and all_value_ vector
// We need to use the max index rather than number of nodes as we use Node.Index()
// when inserting into node_offsets_
auto max_node_index = graph.MaxNodeIndex();
node_offsets_.resize(max_node_index);
auto& mlvalue_idx_map = session_state_.GetMLValueNameIdxMap();
// 1. resize the all_value_ vector
all_values_.resize(mlvalue_idx_map.MaxIdx() + 1);
// 2. handle the weights.
for (const auto& entry : session_state_.GetInitializedTensors()) {
auto mlvalue_index = entry.first;
all_values_[mlvalue_index] = entry.second; // this copy should be cheap
}
// 3. handle feed in values
for (const auto& feed : feeds) {
int mlvalue_idx;
Status status = mlvalue_idx_map.GetIdx(feed.first, mlvalue_idx);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
// we are sharing the underline tensor/object for MLValue
all_values_[mlvalue_idx] = feed.second;
}
// 4. Handle non-empty output vector
// 2. Handle non-empty output vector
if (!fetches.empty()) {
// should've already verified this much before when Run() starts
ORT_ENFORCE(output_names.size() == fetches.size(),
"output_names vector size: " + std::to_string(output_names.size()) +
" does not match that of fetches vector: " + std::to_string(fetches.size()));
// setup output_indices_, we dont' want to generate mem plan on output tensors.
// setup output_indices_, we don't want to generate mem plan on output tensors.
output_indices_.reserve(output_names.size());
auto idx = 0;
for (const auto& oname : output_names) {
@ -375,45 +325,25 @@ void ExecutionFrame::Init(const onnxruntime::GraphViewer& graph,
}
}
// 5. set node args
std::size_t total_def_count{};
for (const auto& node : graph.Nodes()) {
node.ForEachDef([&](const onnxruntime::NodeArg& /*arg*/, bool /*is_input*/) {
++total_def_count;
});
// 3. handle the weights.
// We do this after the fetches to handle an edge case (possibly dubious) where a Constant is an output.
// The Constant gets lifted to an initializer so there's no Node producing the value as an output during Graph
// execution (i.e. Graph execution won't write the value to all_values_).
// A non-empty fetches vector will overwrite the actual weight in all_values_[mlvalue_idx] if we did this earlier.
// This makes the ONNX Constant test (onnx\backend\test\data\node\test_constant) happy as that
// involves a graph with a single Constant node.
for (const auto& entry : session_state_.GetInitializedTensors()) {
auto mlvalue_index = entry.first;
all_values_[mlvalue_index] = entry.second;
}
node_values_.reserve(total_def_count);
for (auto& node : graph.Nodes()) {
ORT_ENFORCE(node.Index() < node_offsets_.size());
node_offsets_[node.Index()] = static_cast<int>(node_values_.size());
for (auto input_def : node.InputDefs()) {
SetupNodeArg(input_def);
}
for (auto input_def : node.ImplicitInputDefs()) {
SetupNodeArg(input_def);
}
for (auto output_def : node.OutputDefs()) {
SetupNodeArg(output_def);
}
}
}
void ExecutionFrame::SetupNodeArg(const onnxruntime::NodeArg* arg) {
ORT_ENFORCE(arg);
auto& name = arg->Name();
//if the arg's name is empty, it is an not needed optional input/output
//set index to -1
if (name.empty()) {
node_values_.push_back(-1);
} else {
int index;
Status status = session_state_.GetMLValueNameIdxMap().GetIdx(name, index);
// 4. handle feed in values. these can override initializer values so must be last
for (const auto& feed : feeds) {
int mlvalue_idx;
Status status = mlvalue_idx_map.GetIdx(feed.first, mlvalue_idx);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
node_values_.push_back(index);
// we are sharing the underline tensor/object for MLValue
all_values_[mlvalue_idx] = feed.second;
}
}
@ -452,10 +382,14 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const {
return planner_->GeneratePatterns(out);
}
int ExecutionFrame::GetNodeOffset(onnxruntime::NodeIndex node_index) const {
return node_index_info_.GetNodeOffset(node_index);
}
// Return nullptr if index map to an value that is an unused optional input/output
const MLValue* ExecutionFrame::GetNodeInputOrOutputMLValue(int index) const {
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < node_values_.size());
return node_values_[index] >= 0 ? &all_values_[node_values_[index]] : nullptr;
int mlvalue_idx = node_index_info_.GetMLValueIndex(index);
return mlvalue_idx != NodeIndexInfo::kInvalidEntry ? &all_values_[mlvalue_idx] : nullptr;
}
// Return nullptr if index map to an value that is an unused optional input/output
@ -483,18 +417,15 @@ static inline void VerifyShape(const MLValue* p_mlvalue,
Status ExecutionFrame::GetOrCreateNodeOutputMLValue(int index,
const MLValueAllocationParameters& parameters,
MLValue*& p_mlvalue) {
if (index < 0 || static_cast<size_t>(index) >= node_values_.size()) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Try to access with invalid node value index: " + std::to_string(index));
}
int mlvalue_idx = node_index_info_.GetMLValueIndex(index);
// return nullptr if it is optional
if (node_values_[index] < 0) {
if (mlvalue_idx == NodeIndexInfo::kInvalidEntry) {
p_mlvalue = nullptr;
return Status::OK();
}
p_mlvalue = &all_values_.at(node_values_[index]);
p_mlvalue = &all_values_.at(mlvalue_idx);
if (p_mlvalue->IsAllocated()) {
// The ml has already been allocated.
@ -505,12 +436,12 @@ Status ExecutionFrame::GetOrCreateNodeOutputMLValue(int index,
// It's not allocated, then allocate it with given shape and return.
// Perform allocation based on the allocation plan
ORT_RETURN_IF_ERROR(AllocateAsPerAllocationPlan(node_values_[index], parameters));
ORT_RETURN_IF_ERROR(AllocateAsPerAllocationPlan(mlvalue_idx, parameters));
return Status::OK();
}
Status ExecutionFrame::ReleaseMLValue(int mlvalue_idx) {
if (mlvalue_idx < 0 || static_cast<size_t>(mlvalue_idx) >= all_values_.size()) {
if (mlvalue_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(mlvalue_idx) >= all_values_.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", mlvalue_idx);
}
all_values_[mlvalue_idx] = MLValue();
@ -521,7 +452,7 @@ Status ExecutionFrame::ReleaseMLValue(int mlvalue_idx) {
const SequentialExecutionPlan::AllocPlanPerValue& ExecutionFrame::GetAllocationPlan(int mlvalue_idx) {
const SequentialExecutionPlan* p_seq_exec_plan = session_state_.GetExecutionPlan();
const auto& alloc_plan = p_seq_exec_plan->allocation_plan;
ORT_ENFORCE(mlvalue_idx >= 0 && mlvalue_idx < alloc_plan.size());
ORT_ENFORCE(mlvalue_idx != NodeIndexInfo::kInvalidEntry && mlvalue_idx < alloc_plan.size());
return alloc_plan[mlvalue_idx];
}
} // namespace onnxruntime

View file

@ -19,20 +19,19 @@ namespace onnxruntime {
class SessionState;
class MLValuePatternPlanner;
struct MemoryPatternGroup;
class NodeIndexInfo;
struct MLValueAllocationParameters {
MLValueAllocationParameters() = default;
MLValueAllocationParameters(const TensorShape* shape)
: tensor_shape{ shape }
{}
: tensor_shape{shape} {}
const TensorShape& GetTensorShape() const
{
const TensorShape& GetTensorShape() const {
static const TensorShape s_empty_tensor_shape;
return tensor_shape != nullptr ? *tensor_shape : s_empty_tensor_shape;
}
private:
private:
const TensorShape* tensor_shape{};
// todo: is there any parameter needed for ml types?
};
@ -48,6 +47,9 @@ class ExecutionFrame {
~ExecutionFrame();
// TODO: These two AllocateMLValue... methods are in the API purely for unit test usage.
// Fix the unit tests so they set an execution plan that results in these methods being called by
// GetOrCreateNodeOutputMLValue instead
Status AllocateMLValueTensorSelfOwnBuffer(int mlvalue_index,
MLDataType element_type,
const OrtAllocatorInfo& location,
@ -60,30 +62,6 @@ class ExecutionFrame {
const OrtAllocatorInfo& location,
const TensorShape& shape,
bool create_fence = false);
// ?? Cheng: What about non-tensor values??
// ?? Cheng: There are cases we may not want to use ORT_ENFORCE??
// ?? Cheng: Graph must be immutable for GetNodesInTopologicalOrder??
// Create tensor at index mlvalue, and allocate buffer for it.
// This tensor will own this buffer.
// This method is not thread safe!
Status AllocateTensorWithSelfOwnBuffer(int index,
MLDataType element_type,
const OrtAllocatorInfo& location,
const TensorShape& shape,
bool create_fence = false);
// Create tensor at index mlvalue, with pre-allocate buffer
// This tensor does not own the buffer.
// The executor / planner need to be careful about the
// lifetime of the buffer. Tensor itself won't manage it.
// This method is not thread safe!
Status AllocateTensorWithPreAllocateBuffer(int offset,
void* pBuffer,
MLDataType element_type,
const OrtAllocatorInfo& location,
const TensorShape& shape);
const MLValue& GetMLValue(int mlvalue_index) const {
ORT_ENFORCE(mlvalue_index >= 0 && static_cast<size_t>(mlvalue_index) < all_values_.size());
return all_values_[mlvalue_index];
@ -94,11 +72,8 @@ class ExecutionFrame {
return all_values_[mlvalue_index];
}
// Index to the first argument of the given node.
int GetFirstArgIndex(onnxruntime::NodeIndex index) const {
ORT_ENFORCE(index < node_offsets_.size());
return node_offsets_[index];
}
// Get the index for the first entry of the given node.
int GetNodeOffset(onnxruntime::NodeIndex index) const;
// Return nullptr if index map to an value that is an unused optional input/output
const MLValue* GetNodeInputOrOutputMLValue(int index) const;
@ -128,8 +103,10 @@ class ExecutionFrame {
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExecutionFrame);
// This method is not thread safe!
void Release(int offset);
void Init(const std::unordered_map<std::string, MLValue>& feeds,
const std::vector<std::string>& output_names,
const std::vector<MLValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators);
common::Status AllocateAsPerAllocationPlan(int mlvalue_index,
const MLValueAllocationParameters& parameters);
@ -140,14 +117,6 @@ class ExecutionFrame {
const TensorShape& shape,
bool create_fence);
void Init(const onnxruntime::GraphViewer& graph,
const std::unordered_map<std::string, MLValue>& feeds,
const std::vector<std::string>& output_names,
const std::vector<MLValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators);
void SetupNodeArg(const onnxruntime::NodeArg* arg);
Status AllocateTensorWithPreAllocateBufferHelper(MLValue* p_mlvalue,
void* pBuffer,
MLDataType element_type,
@ -162,22 +131,15 @@ class ExecutionFrame {
Status status_;
// The values for the inputs and outputs of the nodes.
// This vector contains the indices into the all_values_ vector.
std::vector<int> node_values_;
const NodeIndexInfo& node_index_info_;
// All the intermediate values for the entire graph.
// Input and Output values are passed in by executors
std::vector<MLValue> all_values_;
// The start index into node_values_ for all the nodes.
std::vector<int> node_offsets_;
// i-th kernel is still waiting for pending_counts_[i] inputs.
std::vector<int> pending_counts_; // not used currently
std::unordered_map<std::string, int> value_name_to_index_;
// map of index to custom allocator
std::unordered_map<int, IExecutor::CustomAllocator> custom_allocators_;

View file

@ -76,6 +76,8 @@ class ExecutionProviders {
bool Empty() const { return exec_providers_.empty(); }
size_t NumProviders() const { return exec_providers_.size(); }
using const_iterator = typename std::vector<std::unique_ptr<IExecutionProvider>>::const_iterator;
const_iterator begin() const noexcept { return exec_providers_.cbegin(); }
const_iterator end() const noexcept { return exec_providers_.cend(); }

View file

@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/node_index_info.h"
#include "core/framework/mlvalue_name_idx_map.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/node_arg.h"
namespace onnxruntime {
NodeIndexInfo::NodeIndexInfo(const GraphViewer& graph_viewer, const MLValueNameIdxMap& mlvalue_idx_map)
: max_mlvalue_idx_{mlvalue_idx_map.MaxIdx()} {
std::size_t total_def_count{};
bool include_missing_optional_defs = true;
for (const auto& node : graph_viewer.Nodes()) {
node.ForEachDef(
[&](const onnxruntime::NodeArg& /*arg*/, bool /*is_input*/) {
++total_def_count;
},
include_missing_optional_defs);
}
// init all to kInvalidEntry
node_offsets_.resize(graph_viewer.MaxNodeIndex(), kInvalidEntry);
node_values_.resize(total_def_count, kInvalidEntry);
int cur_idx = 0;
for (auto& node : graph_viewer.Nodes()) {
node_offsets_[node.Index()] = cur_idx;
node.ForEachDef(
[&](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) {
auto& name = node_arg.Name();
if (node_arg.Exists()) {
int index;
Status status = mlvalue_idx_map.GetIdx(name, index);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
node_values_[cur_idx] = index;
}
// else it's a missing optional input or output so leave the -1
++cur_idx;
},
include_missing_optional_defs);
}
}
} // namespace onnxruntime

View file

@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <vector>
#include "core/common/common.h"
#include "core/framework/ml_value.h"
namespace onnxruntime {
class GraphViewer;
class MLValueNameIdxMap;
class NodeIndexInfo final {
public:
NodeIndexInfo(const GraphViewer& graph_viewer, const MLValueNameIdxMap& mlvalue_idx_map);
enum { kInvalidEntry = -1 };
// Index to the first argument of the given Node.
// The Node will have (num inputs + num implicit inputs + num outputs) entries, in that order, starting at the
// offset that is returned. Use the offset in calls to GetMLValueIndex.
// Returns kInvalidEntry if the Node with the given node_index did not exist when the NodeIndexInfo was created.
int GetNodeOffset(onnxruntime::NodeIndex node_index) const {
ORT_ENFORCE(node_index < node_offsets_.size());
return node_offsets_[node_index];
}
// Get the mlvalue index value.
// Returns kInvalidEntry for optional inputs/outputs that do not exist in this graph.
int GetMLValueIndex(int offset) const {
ORT_ENFORCE(offset >= 0 && static_cast<size_t>(offset) < node_values_.size());
return node_values_[offset];
}
int GetMaxMLValueIdx() const { return max_mlvalue_idx_; }
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(NodeIndexInfo);
// This vector contains the indices from the MLValueNameIdxMap in the SessionState for each Node's input/outputs.
// Order is node inputs, implicit inputs, outputs.
std::vector<int> node_values_;
// The entry at node_offset_[Node::Index()] contains the index in node_values_ where the information for the Node
// begins.
std::vector<int> node_offsets_;
const int max_mlvalue_idx_;
};
} // namespace onnxruntime

View file

@ -18,7 +18,7 @@ OpKernelContext::OpKernelContext(ExecutionFrame* frame,
ORT_ENFORCE(frame != nullptr, "Execution frame was null");
ORT_ENFORCE(kernel != nullptr, "OpKernel was null");
node_input_start_index_ = frame->GetFirstArgIndex(kernel->Node().Index());
node_input_start_index_ = frame->GetNodeOffset(kernel->Node().Index());
node_implicit_input_start_index_ = node_input_start_index_ + InputCount();
node_output_start_index_ = node_implicit_input_start_index_ + ImplicitInputCount();
}
@ -29,7 +29,7 @@ Tensor* OpKernelContext::Output(int index, const TensorShape& shape) {
// In this case, it's assumed that the tensor hasn't been allocated yet,
// so that it's calling ExecutionFrame to create a tensor in the given position with given shape.
MLValueAllocationParameters parameters{ &shape };
MLValueAllocationParameters parameters{&shape};
//: Though we don't need to give 'ret' an initial value, GCC would generate a warning if we don't do that
//"error: 'ret' may be used uninitialized in this function"

View file

@ -6,6 +6,7 @@
#include <sstream>
#include "core/common/logging/logging.h"
#include "core/framework/node_index_info.h"
#include "core/framework/op_kernel.h"
#include "core/framework/utils.h"
@ -206,4 +207,21 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex
return session_state;
}
void SessionState::CalculateNodeIndexInfo() {
ORT_ENFORCE(graph_viewer_);
node_index_info_ = std::make_unique<NodeIndexInfo>(*graph_viewer_, mlvalue_name_idx_map_);
for (auto& node_to_map_pair : subgraph_session_states_) {
for (auto& attr_name_to_subgraph : node_to_map_pair.second) {
// TEMPORARY const_cast pending changes from PR that moves ownership of the subgraph SessionState into here
const_cast<SessionState*>(attr_name_to_subgraph.second.get())->CalculateNodeIndexInfo();
}
}
}
const NodeIndexInfo& SessionState::GetNodeIndexInfo() const {
ORT_ENFORCE(node_index_info_, "CalculateNodeIndexInfo must be called prior to GetExecutionInfo.");
return *node_index_info_;
}
} // namespace onnxruntime

View file

@ -18,6 +18,7 @@
#include "core/framework/mem_pattern.h"
#include "core/framework/ml_value.h"
#include "core/framework/mlvalue_name_idx_map.h"
#include "core/framework/node_index_info.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/fuse_nodes_funcs.h"
@ -30,6 +31,7 @@ namespace onnxruntime {
class ExecutionProviders;
class KernelDef;
class OpKernel;
class NodeIndexInfo;
struct SequentialExecutionPlan;
struct MemoryPatternGroup;
@ -165,6 +167,9 @@ class SessionState {
void SetExportDllFlag(bool flag) { export_fused_dll_ = flag; }
const FuncManager* GetFuncMgr() const { return &fused_funcs_mgr_; }
void CalculateNodeIndexInfo();
const NodeIndexInfo& GetNodeIndexInfo() const;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
@ -208,5 +213,7 @@ class SessionState {
bool export_fused_dll_ = false;
FuncManager fused_funcs_mgr_;
std::unique_ptr<NodeIndexInfo> node_index_info_;
};
} // namespace onnxruntime

View file

@ -133,8 +133,8 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state,
//no copy for TRT
if (required_provider_type == onnxruntime::kTRTExecutionProvider) {
new_mlvalue = orig_mlvalue;
return Status::OK();
new_mlvalue = orig_mlvalue;
return Status::OK();
}
auto input_provider_type = p_input_provider->Type();
@ -218,7 +218,6 @@ common::Status MatchOutputsWithProviders(const SessionState& session_state,
for (auto* arg : node.OutputDefs()) {
if (!arg->Exists() ||
arg->Name().empty() ||
!(found = Contains(output_names, arg->Name())).first) {
continue;
}
@ -258,44 +257,6 @@ common::Status MatchOutputsWithProviders(const SessionState& session_state,
}
}
// If we've already seen all the outputs requested just return.
if (seen_outputs.size() == output_names.size()) {
return Status::OK();
}
// Handle the case when a constant is an output but has been folded into a weight
// and hence it doesn't show up in any of the OutputDefs before.
// assume that the weight has already been placed in the appropriate device before
auto& defs = p_graph->GetOutputs();
auto& mlvalue_name_idx_map{session_state.GetMLValueNameIdxMap()};
auto& weights = session_state.GetInitializedTensors();
for (auto& one_def : defs) {
if (!one_def->Exists() ||
one_def->Name().empty() ||
seen_outputs.count(one_def->Name()) ||
!(found = Contains(output_names, one_def->Name())).first) {
continue;
}
auto& def_name = one_def->Name();
size_t idx = found.second;
int mlvalue_idx;
ORT_RETURN_IF_ERROR(mlvalue_name_idx_map.GetIdx(def_name, mlvalue_idx));
if (!weights.count(mlvalue_idx)) {
LOGS(session_state.Logger(), INFO) << "Output with name " << def_name << " is not a weight.";
continue;
}
seen_outputs.insert(def_name);
const auto& weight = weights.at(mlvalue_idx);
new_fetches[idx] = weight;
}
if (seen_outputs.size() != output_names.size()) // make sure we've seen all outputs
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output size mismatch, expected ", output_names.size(),
" got ", seen_outputs.size());
return Status::OK();
}
@ -369,16 +330,6 @@ common::Status ExecuteGraph(const SessionState& session_state,
bool sequential_execution,
const bool& terminate_flag,
const logging::Logger& logger) {
// TODO: Would be better to check upfront whether there was a need to copy inputs/outputs across devices,
// especially when a subgraph is repeatedly executed in a Scan or Loop node. If we checked once and no copy was
// needed we can skip everything here apart from the Execute call.
NameMLValMap device_feeds;
ORT_RETURN_IF_ERROR(utils::CopyInputsAcrossDevices(session_state, feeds, device_feeds));
std::vector<MLValue> device_fetches;
ORT_RETURN_IF_ERROR(utils::MatchOutputsWithProviders(session_state, output_names, fetches, device_fetches));
std::unique_ptr<IExecutor> p_exec;
if (sequential_execution) {
@ -387,9 +338,28 @@ common::Status ExecuteGraph(const SessionState& session_state,
p_exec = std::unique_ptr<IExecutor>(new ParallelExecutor(session_state, terminate_flag));
}
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state, device_feeds, output_names, device_fetches, fetch_allocators, logger));
ORT_RETURN_IF_ERROR(utils::CopyOutputsAcrossDevices(session_state, device_fetches, fetches));
// If we only have one provider it's the CPU provider as that is always automatically registered. If that's the
// case, assume no copy to/from other devices is required.
// TODO: Next step: If there is more than one provider we could add an in/out param to track whether any
// copy to/from devices was needed, and set that on the first execution. That way when a subgraph is repeatedly
// executed in a Scan or Loop node we can skip unnecessary checks for copies.
if (session_state.GetExecutionProviders().NumProviders() == 1) {
// no device copies are needed so simple execute
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state, feeds, output_names, fetches, fetch_allocators, logger));
} else {
NameMLValMap device_feeds;
ORT_RETURN_IF_ERROR(utils::CopyInputsAcrossDevices(session_state, feeds, device_feeds));
std::vector<MLValue> device_fetches;
ORT_RETURN_IF_ERROR(utils::MatchOutputsWithProviders(session_state, output_names, fetches, device_fetches));
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state, device_feeds, output_names, device_fetches, fetch_allocators,
logger));
ORT_RETURN_IF_ERROR(utils::CopyOutputsAcrossDevices(session_state, device_fetches, fetches));
}
return Status::OK();
}

View file

@ -550,19 +550,20 @@ const Graph* Node::GetGraphAttribute(const std::string& attr_name) const {
return const_cast<Node*>(this)->GetMutableGraphAttribute(attr_name);
}
void Node::ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func) const {
void Node::ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
bool include_missing_optional_defs) const {
for (const auto* arg : InputDefs()) {
if (arg->Exists())
if (include_missing_optional_defs || arg->Exists())
func(*arg, true);
}
for (const auto* arg : ImplicitInputDefs()) {
if (arg->Exists())
if (include_missing_optional_defs || arg->Exists())
func(*arg, true);
}
for (const auto* arg : OutputDefs()) {
if (arg->Exists())
if (include_missing_optional_defs || arg->Exists())
func(*arg, false);
}
};

View file

@ -373,6 +373,8 @@ class InferenceSession::Impl {
// handle any subgraphs
ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_));
session_state_.CalculateNodeIndexInfo();
is_inited_ = true;
LOGS(*session_logger_, INFO) << "Session successfully initialized.";

View file

@ -70,15 +70,17 @@ TEST(ExecutionFrameTest, TensorAllocationTest) {
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
state.SetExecutionPlan(std::move(p_seq_exec_plan));
state.CalculateNodeIndexInfo();
vector<MLValue> outputs;
ExecutionFrame frame(std::unordered_map<std::string, MLValue>{}, std::vector<std::string>{}, outputs, {}, state);
int start_index = frame.GetFirstArgIndex(node->Index());
int start_index = frame.GetNodeOffset(node->Index());
EXPECT_EQ(start_index, 0);
TensorShape shape(std::vector<int64_t>{2, 3});
status = frame.AllocateTensorWithSelfOwnBuffer(start_index, DataTypeImpl::GetType<float>(),
execution_providers.Get(xp_typ)->GetAllocator(0, OrtMemTypeDefault)->Info(), shape);
status = frame.AllocateMLValueTensorSelfOwnBuffer(start_index, DataTypeImpl::GetType<float>(),
execution_providers.Get(xp_typ)->GetAllocator(0, OrtMemTypeDefault)->Info(), shape);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
MLValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0);
@ -89,12 +91,11 @@ TEST(ExecutionFrameTest, TensorAllocationTest) {
//test share memory from tensor
TensorShape shape2(std::vector<int64_t>{3, 2});
status = frame.AllocateTensorWithPreAllocateBuffer(
start_index + 1,
p_tensor->template MutableData<float>(),
DataTypeImpl::GetType<float>(),
p_tensor->Location(),
shape2);
status = frame.AllocateMLValueTensorPreAllocateBuffer(start_index + 1,
start_index,
DataTypeImpl::GetType<float>(),
p_tensor->Location(),
shape2);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
const MLValue* p_ml_value_const = frame.GetNodeInputOrOutputMLValue(1);
@ -144,6 +145,8 @@ TEST(ExecutionFrameTest, FeedInDataTest) {
mlvalue_name_idx_map.Add("X");
mlvalue_name_idx_map.Add("Y");
state.CalculateNodeIndexInfo();
vector<MLValue> outputs;
ExecutionFrame frame(std::unordered_map<std::string, MLValue>{{"X", value}},
std::vector<std::string>{}, outputs, {}, state);
@ -221,6 +224,8 @@ TEST(ExecutionFrameTest, MemPatternTest) {
state.SetExecutionPlan(std::move(p_seq_exec_plan));
state.CalculateNodeIndexInfo();
vector<MLValue> outputs;
ExecutionFrame frame(std::unordered_map<std::string, MLValue>{{"X1", v1}, {"X2", v2}, {"X3", v3}},
std::vector<std::string>{"T3"}, outputs, {}, state);
@ -250,7 +255,7 @@ TEST(ExecutionFrameTest, MemPatternTest) {
EXPECT_EQ(pattern.patterns.size(), pattern.locations.size());
EXPECT_EQ(pattern.patterns.size(), 1);
auto p = pattern.GetPatterns(cpu_allocator->Info());
EXPECT_EQ(p->PeakSize(), 2 * 64); // each allocation is 64-byte aligned
EXPECT_EQ(p->PeakSize(), 2 * 64); // each allocation is 64-byte aligned
EXPECT_EQ(p->GetBlock(3)->offset_, 0);
EXPECT_EQ(p->GetBlock(4)->offset_, 64);
}