From efb72540be4b2a7900178b1e0097779b0baa06db Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 1 Feb 2019 10:55:49 +1000 Subject: [PATCH] Separate out constant node index information from ExecutionFrame (#410) * Separate out the NodeArg index information from ExecutionFrame so it is only calculated once. * Skip copy to/from device if only CPU execution provider is registered. Cleanups. * Address PR comments. Clean up a few areas. * Fix Linux build error --- include/onnxruntime/core/graph/graph.h | 14 +- onnxruntime/core/framework/execution_frame.cc | 145 +++++------------- onnxruntime/core/framework/execution_frame.h | 66 ++------ .../core/framework/execution_providers.h | 2 + onnxruntime/core/framework/node_index_info.cc | 51 ++++++ onnxruntime/core/framework/node_index_info.h | 52 +++++++ onnxruntime/core/framework/op_kernel.cc | 4 +- onnxruntime/core/framework/session_state.cc | 18 +++ onnxruntime/core/framework/session_state.h | 7 + onnxruntime/core/framework/utils.cc | 76 +++------ onnxruntime/core/graph/graph.cc | 9 +- onnxruntime/core/session/inference_session.cc | 2 + .../test/framework/execution_frame_test.cc | 25 +-- 13 files changed, 238 insertions(+), 233 deletions(-) create mode 100644 onnxruntime/core/framework/node_index_info.cc create mode 100644 onnxruntime/core/framework/node_index_info.h diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 15ff88f73c..39dfbf3d34 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -55,7 +55,7 @@ class Node { @param node The source node if this is an input edge to the current node, or the destination node if this is an output edge from the current node. @param src_arg_index The node arg index of source node of the edge. - @param dst_arg_index The node arg index of destination node of the edge. + @param dst_arg_index The node arg index of destination node of the edge. */ EdgeEnd(const Node& node, int src_arg_index, int dst_arg_index) noexcept; @@ -68,11 +68,11 @@ class Node { const Node& GetNode() const noexcept; /** Gets the source arg index. - @returns the source arg index of <*this> edge.*/ + @returns the source arg index of <*this> edge.*/ int GetSrcArgIndex() const; /** Gets the destination arg index. - @returns the destination arg index of <*this> edge.*/ + @returns the destination arg index of <*this> edge.*/ int GetDstArgIndex() const; private: @@ -283,8 +283,12 @@ class Node { void ToProto(ONNX_NAMESPACE::NodeProto& proto) const; /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node. - If the NodeArg is an explicit or implicit input, is_input will be true when func is called. */ - void ForEachDef(std::function func) const; + If the NodeArg is an explicit or implicit input, is_input will be true when func is called. + @param include_missing_optional_defs Include NodeArgs that are optional and were not provided + i.e. NodeArg::Exists() == false. + */ + void ForEachDef(std::function func, + bool include_missing_optional_defs = false) const; /** Replaces any matching definitions in the Node's explicit inputs or explicit outputs. @param replacements Map of current NodeArg to replacement NodeArg. diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index e38af435fa..f43c43770b 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -7,6 +7,7 @@ #include "core/framework/mem_pattern_planner.h" #include "core/framework/ml_value_patterns_planner.h" +#include "core/framework/node_index_info.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" #include "core/framework/utils.h" @@ -20,12 +21,11 @@ ExecutionFrame::ExecutionFrame(const std::unordered_map& f const std::vector& fetches, const std::unordered_map& fetch_allocators, const SessionState& session_state) - : session_state_(session_state), + : node_index_info_(session_state.GetNodeIndexInfo()), + session_state_(session_state), mem_patterns_(nullptr), planner_(nullptr) { - auto* graph = session_state.GetGraphViewer(); - ORT_ENFORCE(graph); - Init(*graph, feeds, output_names, fetches, fetch_allocators); + Init(feeds, output_names, fetches, fetch_allocators); // If the session enable memory pattern optimization // and we have execution plan generated, try to setup @@ -79,7 +79,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(int mlvalue_inde const OrtAllocatorInfo& location, const TensorShape& shape, bool create_fence) { - if (mlvalue_index < 0) + if (mlvalue_index == NodeIndexInfo::kInvalidEntry) return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs"); auto p_mlvalue = &all_values_[mlvalue_index]; @@ -166,15 +166,6 @@ void ExecutionFrame::TraceAllocate(int mlvalue_idx, size_t size) { } } -Status ExecutionFrame::AllocateTensorWithSelfOwnBuffer(const int index, - const DataTypeImpl* element_type, - const OrtAllocatorInfo& location, - const TensorShape& shape, - bool create_fence) { - ORT_ENFORCE(index >= 0 && static_cast(index) < node_values_.size()); - return AllocateMLValueTensorSelfOwnBufferHelper(node_values_[index], element_type, location, shape, create_fence); -} - Status ExecutionFrame::AllocateMLValueTensorPreAllocateBuffer(int mlvalue_index_to_allocate, int mlvalue_index_reuse, const DataTypeImpl* element_type, @@ -221,26 +212,6 @@ Status ExecutionFrame::AllocateTensorWithPreAllocateBufferHelper(MLValue* p_mlva return Status::OK(); } -Status ExecutionFrame::AllocateTensorWithPreAllocateBuffer(const int offset, - void* pBuffer, - const DataTypeImpl* element_type, - const OrtAllocatorInfo& location, - const TensorShape& shape) { - ORT_ENFORCE(offset >= 0 && offset < node_values_.size()); - if (node_values_[offset] < 0) - return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs"); - auto value = &all_values_[node_values_[offset]]; - return AllocateTensorWithPreAllocateBufferHelper(value, pBuffer, element_type, location, shape); -} - -void ExecutionFrame::Release(const int offset) { - ORT_ENFORCE(offset >= 0 && offset < node_offsets_.size()); - if (node_values_[offset] >= 0 && node_values_[offset] < all_values_.size()) { - all_values_[node_values_[offset]] = MLValue(); - TraceFree(node_values_[offset]); - } -} - Status AllocateTraditionalMLValue(MLValue* p_mlvalue, const NonTensorTypeBase* type, const MLValueAllocationParameters& parameters) { @@ -257,7 +228,7 @@ Status AllocateTraditionalMLValue(MLValue* p_mlvalue, // This method is not thread safe! Status ExecutionFrame::AllocateAsPerAllocationPlan(int mlvalue_index, const MLValueAllocationParameters& parameters) { - if (mlvalue_index < 0 || mlvalue_index >= all_values_.size()) + if (mlvalue_index == NodeIndexInfo::kInvalidEntry || mlvalue_index >= all_values_.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tried to allocated with invalid mlvalue index: " + std::to_string(mlvalue_index)); @@ -319,44 +290,23 @@ Status ExecutionFrame::AllocateAsPerAllocationPlan(int mlvalue_index, return Status::OK(); } -void ExecutionFrame::Init(const onnxruntime::GraphViewer& graph, - const std::unordered_map& feeds, +void ExecutionFrame::Init(const std::unordered_map& feeds, const std::vector& output_names, const std::vector& fetches, const std::unordered_map& fetch_allocators) { - // 1. resize the node_offsets and all_value_ vector - // We need to use the max index rather than number of nodes as we use Node.Index() - // when inserting into node_offsets_ - auto max_node_index = graph.MaxNodeIndex(); - node_offsets_.resize(max_node_index); - auto& mlvalue_idx_map = session_state_.GetMLValueNameIdxMap(); + // 1. resize the all_value_ vector all_values_.resize(mlvalue_idx_map.MaxIdx() + 1); - // 2. handle the weights. - for (const auto& entry : session_state_.GetInitializedTensors()) { - auto mlvalue_index = entry.first; - all_values_[mlvalue_index] = entry.second; // this copy should be cheap - } - - // 3. handle feed in values - for (const auto& feed : feeds) { - int mlvalue_idx; - Status status = mlvalue_idx_map.GetIdx(feed.first, mlvalue_idx); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - // we are sharing the underline tensor/object for MLValue - all_values_[mlvalue_idx] = feed.second; - } - - // 4. Handle non-empty output vector + // 2. Handle non-empty output vector if (!fetches.empty()) { // should've already verified this much before when Run() starts ORT_ENFORCE(output_names.size() == fetches.size(), "output_names vector size: " + std::to_string(output_names.size()) + " does not match that of fetches vector: " + std::to_string(fetches.size())); - // setup output_indices_, we dont' want to generate mem plan on output tensors. + // setup output_indices_, we don't want to generate mem plan on output tensors. output_indices_.reserve(output_names.size()); auto idx = 0; for (const auto& oname : output_names) { @@ -375,45 +325,25 @@ void ExecutionFrame::Init(const onnxruntime::GraphViewer& graph, } } - // 5. set node args - std::size_t total_def_count{}; - for (const auto& node : graph.Nodes()) { - node.ForEachDef([&](const onnxruntime::NodeArg& /*arg*/, bool /*is_input*/) { - ++total_def_count; - }); + // 3. handle the weights. + // We do this after the fetches to handle an edge case (possibly dubious) where a Constant is an output. + // The Constant gets lifted to an initializer so there's no Node producing the value as an output during Graph + // execution (i.e. Graph execution won't write the value to all_values_). + // A non-empty fetches vector will overwrite the actual weight in all_values_[mlvalue_idx] if we did this earlier. + // This makes the ONNX Constant test (onnx\backend\test\data\node\test_constant) happy as that + // involves a graph with a single Constant node. + for (const auto& entry : session_state_.GetInitializedTensors()) { + auto mlvalue_index = entry.first; + all_values_[mlvalue_index] = entry.second; } - node_values_.reserve(total_def_count); - for (auto& node : graph.Nodes()) { - ORT_ENFORCE(node.Index() < node_offsets_.size()); - node_offsets_[node.Index()] = static_cast(node_values_.size()); - - for (auto input_def : node.InputDefs()) { - SetupNodeArg(input_def); - } - - for (auto input_def : node.ImplicitInputDefs()) { - SetupNodeArg(input_def); - } - - for (auto output_def : node.OutputDefs()) { - SetupNodeArg(output_def); - } - } -} - -void ExecutionFrame::SetupNodeArg(const onnxruntime::NodeArg* arg) { - ORT_ENFORCE(arg); - auto& name = arg->Name(); - //if the arg's name is empty, it is an not needed optional input/output - //set index to -1 - if (name.empty()) { - node_values_.push_back(-1); - } else { - int index; - Status status = session_state_.GetMLValueNameIdxMap().GetIdx(name, index); + // 4. handle feed in values. these can override initializer values so must be last + for (const auto& feed : feeds) { + int mlvalue_idx; + Status status = mlvalue_idx_map.GetIdx(feed.first, mlvalue_idx); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - node_values_.push_back(index); + // we are sharing the underline tensor/object for MLValue + all_values_[mlvalue_idx] = feed.second; } } @@ -452,10 +382,14 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const { return planner_->GeneratePatterns(out); } +int ExecutionFrame::GetNodeOffset(onnxruntime::NodeIndex node_index) const { + return node_index_info_.GetNodeOffset(node_index); +} + // Return nullptr if index map to an value that is an unused optional input/output const MLValue* ExecutionFrame::GetNodeInputOrOutputMLValue(int index) const { - ORT_ENFORCE(index >= 0 && static_cast(index) < node_values_.size()); - return node_values_[index] >= 0 ? &all_values_[node_values_[index]] : nullptr; + int mlvalue_idx = node_index_info_.GetMLValueIndex(index); + return mlvalue_idx != NodeIndexInfo::kInvalidEntry ? &all_values_[mlvalue_idx] : nullptr; } // Return nullptr if index map to an value that is an unused optional input/output @@ -483,18 +417,15 @@ static inline void VerifyShape(const MLValue* p_mlvalue, Status ExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const MLValueAllocationParameters& parameters, MLValue*& p_mlvalue) { - if (index < 0 || static_cast(index) >= node_values_.size()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Try to access with invalid node value index: " + std::to_string(index)); - } + int mlvalue_idx = node_index_info_.GetMLValueIndex(index); // return nullptr if it is optional - if (node_values_[index] < 0) { + if (mlvalue_idx == NodeIndexInfo::kInvalidEntry) { p_mlvalue = nullptr; return Status::OK(); } - p_mlvalue = &all_values_.at(node_values_[index]); + p_mlvalue = &all_values_.at(mlvalue_idx); if (p_mlvalue->IsAllocated()) { // The ml has already been allocated. @@ -505,12 +436,12 @@ Status ExecutionFrame::GetOrCreateNodeOutputMLValue(int index, // It's not allocated, then allocate it with given shape and return. // Perform allocation based on the allocation plan - ORT_RETURN_IF_ERROR(AllocateAsPerAllocationPlan(node_values_[index], parameters)); + ORT_RETURN_IF_ERROR(AllocateAsPerAllocationPlan(mlvalue_idx, parameters)); return Status::OK(); } Status ExecutionFrame::ReleaseMLValue(int mlvalue_idx) { - if (mlvalue_idx < 0 || static_cast(mlvalue_idx) >= all_values_.size()) { + if (mlvalue_idx == NodeIndexInfo::kInvalidEntry || static_cast(mlvalue_idx) >= all_values_.size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", mlvalue_idx); } all_values_[mlvalue_idx] = MLValue(); @@ -521,7 +452,7 @@ Status ExecutionFrame::ReleaseMLValue(int mlvalue_idx) { const SequentialExecutionPlan::AllocPlanPerValue& ExecutionFrame::GetAllocationPlan(int mlvalue_idx) { const SequentialExecutionPlan* p_seq_exec_plan = session_state_.GetExecutionPlan(); const auto& alloc_plan = p_seq_exec_plan->allocation_plan; - ORT_ENFORCE(mlvalue_idx >= 0 && mlvalue_idx < alloc_plan.size()); + ORT_ENFORCE(mlvalue_idx != NodeIndexInfo::kInvalidEntry && mlvalue_idx < alloc_plan.size()); return alloc_plan[mlvalue_idx]; } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 65dc7b4599..ed1f1b347d 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -19,20 +19,19 @@ namespace onnxruntime { class SessionState; class MLValuePatternPlanner; struct MemoryPatternGroup; +class NodeIndexInfo; struct MLValueAllocationParameters { MLValueAllocationParameters() = default; MLValueAllocationParameters(const TensorShape* shape) - : tensor_shape{ shape } - {} + : tensor_shape{shape} {} - const TensorShape& GetTensorShape() const - { + const TensorShape& GetTensorShape() const { static const TensorShape s_empty_tensor_shape; return tensor_shape != nullptr ? *tensor_shape : s_empty_tensor_shape; } -private: + private: const TensorShape* tensor_shape{}; // todo: is there any parameter needed for ml types? }; @@ -48,6 +47,9 @@ class ExecutionFrame { ~ExecutionFrame(); + // TODO: These two AllocateMLValue... methods are in the API purely for unit test usage. + // Fix the unit tests so they set an execution plan that results in these methods being called by + // GetOrCreateNodeOutputMLValue instead Status AllocateMLValueTensorSelfOwnBuffer(int mlvalue_index, MLDataType element_type, const OrtAllocatorInfo& location, @@ -60,30 +62,6 @@ class ExecutionFrame { const OrtAllocatorInfo& location, const TensorShape& shape, bool create_fence = false); - - // ?? Cheng: What about non-tensor values?? - // ?? Cheng: There are cases we may not want to use ORT_ENFORCE?? - // ?? Cheng: Graph must be immutable for GetNodesInTopologicalOrder?? - // Create tensor at index mlvalue, and allocate buffer for it. - // This tensor will own this buffer. - // This method is not thread safe! - Status AllocateTensorWithSelfOwnBuffer(int index, - MLDataType element_type, - const OrtAllocatorInfo& location, - const TensorShape& shape, - bool create_fence = false); - - // Create tensor at index mlvalue, with pre-allocate buffer - // This tensor does not own the buffer. - // The executor / planner need to be careful about the - // lifetime of the buffer. Tensor itself won't manage it. - // This method is not thread safe! - Status AllocateTensorWithPreAllocateBuffer(int offset, - void* pBuffer, - MLDataType element_type, - const OrtAllocatorInfo& location, - const TensorShape& shape); - const MLValue& GetMLValue(int mlvalue_index) const { ORT_ENFORCE(mlvalue_index >= 0 && static_cast(mlvalue_index) < all_values_.size()); return all_values_[mlvalue_index]; @@ -94,11 +72,8 @@ class ExecutionFrame { return all_values_[mlvalue_index]; } - // Index to the first argument of the given node. - int GetFirstArgIndex(onnxruntime::NodeIndex index) const { - ORT_ENFORCE(index < node_offsets_.size()); - return node_offsets_[index]; - } + // Get the index for the first entry of the given node. + int GetNodeOffset(onnxruntime::NodeIndex index) const; // Return nullptr if index map to an value that is an unused optional input/output const MLValue* GetNodeInputOrOutputMLValue(int index) const; @@ -128,8 +103,10 @@ class ExecutionFrame { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExecutionFrame); - // This method is not thread safe! - void Release(int offset); + void Init(const std::unordered_map& feeds, + const std::vector& output_names, + const std::vector& fetches, + const std::unordered_map& fetch_allocators); common::Status AllocateAsPerAllocationPlan(int mlvalue_index, const MLValueAllocationParameters& parameters); @@ -140,14 +117,6 @@ class ExecutionFrame { const TensorShape& shape, bool create_fence); - void Init(const onnxruntime::GraphViewer& graph, - const std::unordered_map& feeds, - const std::vector& output_names, - const std::vector& fetches, - const std::unordered_map& fetch_allocators); - - void SetupNodeArg(const onnxruntime::NodeArg* arg); - Status AllocateTensorWithPreAllocateBufferHelper(MLValue* p_mlvalue, void* pBuffer, MLDataType element_type, @@ -162,22 +131,15 @@ class ExecutionFrame { Status status_; - // The values for the inputs and outputs of the nodes. - // This vector contains the indices into the all_values_ vector. - std::vector node_values_; + const NodeIndexInfo& node_index_info_; // All the intermediate values for the entire graph. // Input and Output values are passed in by executors std::vector all_values_; - // The start index into node_values_ for all the nodes. - std::vector node_offsets_; - // i-th kernel is still waiting for pending_counts_[i] inputs. std::vector pending_counts_; // not used currently - std::unordered_map value_name_to_index_; - // map of index to custom allocator std::unordered_map custom_allocators_; diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 4156d23458..fdb51f0e18 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -76,6 +76,8 @@ class ExecutionProviders { bool Empty() const { return exec_providers_.empty(); } + size_t NumProviders() const { return exec_providers_.size(); } + using const_iterator = typename std::vector>::const_iterator; const_iterator begin() const noexcept { return exec_providers_.cbegin(); } const_iterator end() const noexcept { return exec_providers_.cend(); } diff --git a/onnxruntime/core/framework/node_index_info.cc b/onnxruntime/core/framework/node_index_info.cc new file mode 100644 index 0000000000..04b844018d --- /dev/null +++ b/onnxruntime/core/framework/node_index_info.cc @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/node_index_info.h" + +#include "core/framework/mlvalue_name_idx_map.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/node_arg.h" + +namespace onnxruntime { + +NodeIndexInfo::NodeIndexInfo(const GraphViewer& graph_viewer, const MLValueNameIdxMap& mlvalue_idx_map) + : max_mlvalue_idx_{mlvalue_idx_map.MaxIdx()} { + std::size_t total_def_count{}; + + bool include_missing_optional_defs = true; + + for (const auto& node : graph_viewer.Nodes()) { + node.ForEachDef( + [&](const onnxruntime::NodeArg& /*arg*/, bool /*is_input*/) { + ++total_def_count; + }, + include_missing_optional_defs); + } + + // init all to kInvalidEntry + node_offsets_.resize(graph_viewer.MaxNodeIndex(), kInvalidEntry); + node_values_.resize(total_def_count, kInvalidEntry); + int cur_idx = 0; + + for (auto& node : graph_viewer.Nodes()) { + node_offsets_[node.Index()] = cur_idx; + + node.ForEachDef( + [&](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) { + auto& name = node_arg.Name(); + if (node_arg.Exists()) { + int index; + Status status = mlvalue_idx_map.GetIdx(name, index); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + node_values_[cur_idx] = index; + } + // else it's a missing optional input or output so leave the -1 + + ++cur_idx; + }, + include_missing_optional_defs); + } +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/node_index_info.h b/onnxruntime/core/framework/node_index_info.h new file mode 100644 index 0000000000..efd2b5e8d7 --- /dev/null +++ b/onnxruntime/core/framework/node_index_info.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/common.h" +#include "core/framework/ml_value.h" + +namespace onnxruntime { +class GraphViewer; +class MLValueNameIdxMap; + +class NodeIndexInfo final { + public: + NodeIndexInfo(const GraphViewer& graph_viewer, const MLValueNameIdxMap& mlvalue_idx_map); + + enum { kInvalidEntry = -1 }; + + // Index to the first argument of the given Node. + // The Node will have (num inputs + num implicit inputs + num outputs) entries, in that order, starting at the + // offset that is returned. Use the offset in calls to GetMLValueIndex. + // Returns kInvalidEntry if the Node with the given node_index did not exist when the NodeIndexInfo was created. + int GetNodeOffset(onnxruntime::NodeIndex node_index) const { + ORT_ENFORCE(node_index < node_offsets_.size()); + return node_offsets_[node_index]; + } + + // Get the mlvalue index value. + // Returns kInvalidEntry for optional inputs/outputs that do not exist in this graph. + int GetMLValueIndex(int offset) const { + ORT_ENFORCE(offset >= 0 && static_cast(offset) < node_values_.size()); + return node_values_[offset]; + } + + int GetMaxMLValueIdx() const { return max_mlvalue_idx_; } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(NodeIndexInfo); + + // This vector contains the indices from the MLValueNameIdxMap in the SessionState for each Node's input/outputs. + // Order is node inputs, implicit inputs, outputs. + std::vector node_values_; + + // The entry at node_offset_[Node::Index()] contains the index in node_values_ where the information for the Node + // begins. + std::vector node_offsets_; + + const int max_mlvalue_idx_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index bd705dfb42..c614cbb34b 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -18,7 +18,7 @@ OpKernelContext::OpKernelContext(ExecutionFrame* frame, ORT_ENFORCE(frame != nullptr, "Execution frame was null"); ORT_ENFORCE(kernel != nullptr, "OpKernel was null"); - node_input_start_index_ = frame->GetFirstArgIndex(kernel->Node().Index()); + node_input_start_index_ = frame->GetNodeOffset(kernel->Node().Index()); node_implicit_input_start_index_ = node_input_start_index_ + InputCount(); node_output_start_index_ = node_implicit_input_start_index_ + ImplicitInputCount(); } @@ -29,7 +29,7 @@ Tensor* OpKernelContext::Output(int index, const TensorShape& shape) { // In this case, it's assumed that the tensor hasn't been allocated yet, // so that it's calling ExecutionFrame to create a tensor in the given position with given shape. - MLValueAllocationParameters parameters{ &shape }; + MLValueAllocationParameters parameters{&shape}; //: Though we don't need to give 'ret' an initial value, GCC would generate a warning if we don't do that //"error: 'ret' may be used uninitialized in this function" diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index c15bc3571f..0ba6f8b452 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -6,6 +6,7 @@ #include #include "core/common/logging/logging.h" +#include "core/framework/node_index_info.h" #include "core/framework/op_kernel.h" #include "core/framework/utils.h" @@ -206,4 +207,21 @@ const SessionState* SessionState::GetSubgraphSessionState(onnxruntime::NodeIndex return session_state; } +void SessionState::CalculateNodeIndexInfo() { + ORT_ENFORCE(graph_viewer_); + node_index_info_ = std::make_unique(*graph_viewer_, mlvalue_name_idx_map_); + + for (auto& node_to_map_pair : subgraph_session_states_) { + for (auto& attr_name_to_subgraph : node_to_map_pair.second) { + // TEMPORARY const_cast pending changes from PR that moves ownership of the subgraph SessionState into here + const_cast(attr_name_to_subgraph.second.get())->CalculateNodeIndexInfo(); + } + } +} + +const NodeIndexInfo& SessionState::GetNodeIndexInfo() const { + ORT_ENFORCE(node_index_info_, "CalculateNodeIndexInfo must be called prior to GetExecutionInfo."); + return *node_index_info_; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 55e47841c0..3852296c01 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -18,6 +18,7 @@ #include "core/framework/mem_pattern.h" #include "core/framework/ml_value.h" #include "core/framework/mlvalue_name_idx_map.h" +#include "core/framework/node_index_info.h" #include "core/graph/graph_viewer.h" #include "core/framework/fuse_nodes_funcs.h" @@ -30,6 +31,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelDef; class OpKernel; +class NodeIndexInfo; struct SequentialExecutionPlan; struct MemoryPatternGroup; @@ -165,6 +167,9 @@ class SessionState { void SetExportDllFlag(bool flag) { export_fused_dll_ = flag; } const FuncManager* GetFuncMgr() const { return &fused_funcs_mgr_; } + void CalculateNodeIndexInfo(); + const NodeIndexInfo& GetNodeIndexInfo() const; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState); @@ -208,5 +213,7 @@ class SessionState { bool export_fused_dll_ = false; FuncManager fused_funcs_mgr_; + + std::unique_ptr node_index_info_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index bd931bd9cc..be1b289725 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -133,8 +133,8 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, //no copy for TRT if (required_provider_type == onnxruntime::kTRTExecutionProvider) { - new_mlvalue = orig_mlvalue; - return Status::OK(); + new_mlvalue = orig_mlvalue; + return Status::OK(); } auto input_provider_type = p_input_provider->Type(); @@ -218,7 +218,6 @@ common::Status MatchOutputsWithProviders(const SessionState& session_state, for (auto* arg : node.OutputDefs()) { if (!arg->Exists() || - arg->Name().empty() || !(found = Contains(output_names, arg->Name())).first) { continue; } @@ -258,44 +257,6 @@ common::Status MatchOutputsWithProviders(const SessionState& session_state, } } - // If we've already seen all the outputs requested just return. - if (seen_outputs.size() == output_names.size()) { - return Status::OK(); - } - - // Handle the case when a constant is an output but has been folded into a weight - // and hence it doesn't show up in any of the OutputDefs before. - // assume that the weight has already been placed in the appropriate device before - auto& defs = p_graph->GetOutputs(); - auto& mlvalue_name_idx_map{session_state.GetMLValueNameIdxMap()}; - auto& weights = session_state.GetInitializedTensors(); - - for (auto& one_def : defs) { - if (!one_def->Exists() || - one_def->Name().empty() || - seen_outputs.count(one_def->Name()) || - !(found = Contains(output_names, one_def->Name())).first) { - continue; - } - - auto& def_name = one_def->Name(); - size_t idx = found.second; - int mlvalue_idx; - ORT_RETURN_IF_ERROR(mlvalue_name_idx_map.GetIdx(def_name, mlvalue_idx)); - if (!weights.count(mlvalue_idx)) { - LOGS(session_state.Logger(), INFO) << "Output with name " << def_name << " is not a weight."; - continue; - } - - seen_outputs.insert(def_name); - const auto& weight = weights.at(mlvalue_idx); - new_fetches[idx] = weight; - } - - if (seen_outputs.size() != output_names.size()) // make sure we've seen all outputs - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output size mismatch, expected ", output_names.size(), - " got ", seen_outputs.size()); - return Status::OK(); } @@ -369,16 +330,6 @@ common::Status ExecuteGraph(const SessionState& session_state, bool sequential_execution, const bool& terminate_flag, const logging::Logger& logger) { - // TODO: Would be better to check upfront whether there was a need to copy inputs/outputs across devices, - // especially when a subgraph is repeatedly executed in a Scan or Loop node. If we checked once and no copy was - // needed we can skip everything here apart from the Execute call. - - NameMLValMap device_feeds; - ORT_RETURN_IF_ERROR(utils::CopyInputsAcrossDevices(session_state, feeds, device_feeds)); - - std::vector device_fetches; - ORT_RETURN_IF_ERROR(utils::MatchOutputsWithProviders(session_state, output_names, fetches, device_fetches)); - std::unique_ptr p_exec; if (sequential_execution) { @@ -387,9 +338,28 @@ common::Status ExecuteGraph(const SessionState& session_state, p_exec = std::unique_ptr(new ParallelExecutor(session_state, terminate_flag)); } - ORT_RETURN_IF_ERROR(p_exec->Execute(session_state, device_feeds, output_names, device_fetches, fetch_allocators, logger)); - ORT_RETURN_IF_ERROR(utils::CopyOutputsAcrossDevices(session_state, device_fetches, fetches)); + // If we only have one provider it's the CPU provider as that is always automatically registered. If that's the + // case, assume no copy to/from other devices is required. + // TODO: Next step: If there is more than one provider we could add an in/out param to track whether any + // copy to/from devices was needed, and set that on the first execution. That way when a subgraph is repeatedly + // executed in a Scan or Loop node we can skip unnecessary checks for copies. + + if (session_state.GetExecutionProviders().NumProviders() == 1) { + // no device copies are needed so simple execute + ORT_RETURN_IF_ERROR(p_exec->Execute(session_state, feeds, output_names, fetches, fetch_allocators, logger)); + } else { + NameMLValMap device_feeds; + ORT_RETURN_IF_ERROR(utils::CopyInputsAcrossDevices(session_state, feeds, device_feeds)); + + std::vector device_fetches; + ORT_RETURN_IF_ERROR(utils::MatchOutputsWithProviders(session_state, output_names, fetches, device_fetches)); + + ORT_RETURN_IF_ERROR(p_exec->Execute(session_state, device_feeds, output_names, device_fetches, fetch_allocators, + logger)); + + ORT_RETURN_IF_ERROR(utils::CopyOutputsAcrossDevices(session_state, device_fetches, fetches)); + } return Status::OK(); } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 9c23dadd52..dcf394813e 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -550,19 +550,20 @@ const Graph* Node::GetGraphAttribute(const std::string& attr_name) const { return const_cast(this)->GetMutableGraphAttribute(attr_name); } -void Node::ForEachDef(std::function func) const { +void Node::ForEachDef(std::function func, + bool include_missing_optional_defs) const { for (const auto* arg : InputDefs()) { - if (arg->Exists()) + if (include_missing_optional_defs || arg->Exists()) func(*arg, true); } for (const auto* arg : ImplicitInputDefs()) { - if (arg->Exists()) + if (include_missing_optional_defs || arg->Exists()) func(*arg, true); } for (const auto* arg : OutputDefs()) { - if (arg->Exists()) + if (include_missing_optional_defs || arg->Exists()) func(*arg, false); } }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index adb892409b..3829dfa49b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -373,6 +373,8 @@ class InferenceSession::Impl { // handle any subgraphs ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_)); + session_state_.CalculateNodeIndexInfo(); + is_inited_ = true; LOGS(*session_logger_, INFO) << "Session successfully initialized."; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 0b0531876d..12cf00b2f3 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -70,15 +70,17 @@ TEST(ExecutionFrameTest, TensorAllocationTest) { EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); state.SetExecutionPlan(std::move(p_seq_exec_plan)); + state.CalculateNodeIndexInfo(); + vector outputs; ExecutionFrame frame(std::unordered_map{}, std::vector{}, outputs, {}, state); - int start_index = frame.GetFirstArgIndex(node->Index()); + int start_index = frame.GetNodeOffset(node->Index()); EXPECT_EQ(start_index, 0); TensorShape shape(std::vector{2, 3}); - status = frame.AllocateTensorWithSelfOwnBuffer(start_index, DataTypeImpl::GetType(), - execution_providers.Get(xp_typ)->GetAllocator(0, OrtMemTypeDefault)->Info(), shape); + status = frame.AllocateMLValueTensorSelfOwnBuffer(start_index, DataTypeImpl::GetType(), + execution_providers.Get(xp_typ)->GetAllocator(0, OrtMemTypeDefault)->Info(), shape); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); MLValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); @@ -89,12 +91,11 @@ TEST(ExecutionFrameTest, TensorAllocationTest) { //test share memory from tensor TensorShape shape2(std::vector{3, 2}); - status = frame.AllocateTensorWithPreAllocateBuffer( - start_index + 1, - p_tensor->template MutableData(), - DataTypeImpl::GetType(), - p_tensor->Location(), - shape2); + status = frame.AllocateMLValueTensorPreAllocateBuffer(start_index + 1, + start_index, + DataTypeImpl::GetType(), + p_tensor->Location(), + shape2); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); const MLValue* p_ml_value_const = frame.GetNodeInputOrOutputMLValue(1); @@ -144,6 +145,8 @@ TEST(ExecutionFrameTest, FeedInDataTest) { mlvalue_name_idx_map.Add("X"); mlvalue_name_idx_map.Add("Y"); + state.CalculateNodeIndexInfo(); + vector outputs; ExecutionFrame frame(std::unordered_map{{"X", value}}, std::vector{}, outputs, {}, state); @@ -221,6 +224,8 @@ TEST(ExecutionFrameTest, MemPatternTest) { state.SetExecutionPlan(std::move(p_seq_exec_plan)); + state.CalculateNodeIndexInfo(); + vector outputs; ExecutionFrame frame(std::unordered_map{{"X1", v1}, {"X2", v2}, {"X3", v3}}, std::vector{"T3"}, outputs, {}, state); @@ -250,7 +255,7 @@ TEST(ExecutionFrameTest, MemPatternTest) { EXPECT_EQ(pattern.patterns.size(), pattern.locations.size()); EXPECT_EQ(pattern.patterns.size(), 1); auto p = pattern.GetPatterns(cpu_allocator->Info()); - EXPECT_EQ(p->PeakSize(), 2 * 64); // each allocation is 64-byte aligned + EXPECT_EQ(p->PeakSize(), 2 * 64); // each allocation is 64-byte aligned EXPECT_EQ(p->GetBlock(3)->offset_, 0); EXPECT_EQ(p->GetBlock(4)->offset_, 64); }