mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Avoid round-trip copies for "pass through" subgraph inputs (#8702)
This commit is contained in:
parent
42ba0c5931
commit
7659148d9f
6 changed files with 385 additions and 31 deletions
|
|
@ -121,6 +121,7 @@ class PlannerImpl {
|
|||
PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
|
||||
: context_(context),
|
||||
|
|
@ -130,6 +131,7 @@ class PlannerImpl {
|
|||
outer_scope_node_args_(outer_scope_node_args),
|
||||
execution_providers_(providers),
|
||||
kernel_create_info_map_(kernel_create_info_map),
|
||||
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map) {}
|
||||
|
||||
Status CreatePlan();
|
||||
|
|
@ -144,6 +146,9 @@ class PlannerImpl {
|
|||
const ExecutionProviders& execution_providers_;
|
||||
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map_;
|
||||
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map_;
|
||||
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map_;
|
||||
|
||||
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
|
||||
|
|
@ -261,9 +266,9 @@ class PlannerImpl {
|
|||
// Inputs of Yields are essentially the outputs for FW partial subgraph
|
||||
// Thses tensors will be pass back to pytorch, thus cannot share the buffer with other tensors
|
||||
|
||||
// Unhandled corner case:
|
||||
// Unhandled corner case:
|
||||
// If FW output tensor is consumed by BW graph, and pytorch performs an inplace operation on th returned tensor,
|
||||
// we will run into a buffer corruption problem.
|
||||
// we will run into a buffer corruption problem.
|
||||
// One potential fix is returning a copy of output tensor, if it has downstream dependency
|
||||
auto p_next_node = node.OutputNodesBegin();
|
||||
if (p_next_node != node.OutputNodesEnd() && p_next_node->OpType() == "YieldOp") {
|
||||
|
|
@ -483,6 +488,8 @@ class PlannerImpl {
|
|||
UseCount(initializer_name)++;
|
||||
}
|
||||
|
||||
std::unordered_set<OrtValueIndex> set_node_arg_has_explicit_consumer;
|
||||
|
||||
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
|
||||
auto pnode = graph_viewer_.GetNode(step.node_index);
|
||||
if (pnode == nullptr) {
|
||||
|
|
@ -507,26 +514,56 @@ class PlannerImpl {
|
|||
|
||||
// increment UseCount and add location information if applicable for the provided input def
|
||||
auto process_input = [&graph_inputs, &exec_provider, &p_kernel_def, &is_implicit_input,
|
||||
&set_node_arg_has_explicit_consumer,
|
||||
this](const NodeArg& input, size_t arg_idx) {
|
||||
const auto& name = input.Name();
|
||||
UseCount(name)++;
|
||||
|
||||
bool is_graph_input = (graph_inputs.find(name) != graph_inputs.cend());
|
||||
bool is_outer_scope_arg = std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
|
||||
[&name](const NodeArg* value) {
|
||||
return value && value->Name() == name;
|
||||
}) != outer_scope_node_args_.cend();
|
||||
bool is_subgraph = (parent_node_ != nullptr);
|
||||
|
||||
// If it's a graph input or outer scope node arg, set its plan.
|
||||
// NOTE: Copy nodes should have already been added if a graph input is fed as input
|
||||
// to nodes assigned to different providers.
|
||||
if (graph_inputs.find(name) != graph_inputs.cend() ||
|
||||
std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
|
||||
[&name](const NodeArg* value) {
|
||||
return value && value->Name() == name;
|
||||
}) != outer_scope_node_args_.cend()) {
|
||||
|
||||
if (is_graph_input || is_outer_scope_arg) {
|
||||
OrtValueIndex index = Index(name);
|
||||
|
||||
// implicit inputs do not have an entry in the kernel def, so do nothing to them here, leaving the control
|
||||
// flow op (Loop, Scan, If) to do the necessary copy if the input crosses different provider.
|
||||
// matching logic is used in TransformerMemcpyImpl::ProcessDefs
|
||||
if (!is_implicit_input) {
|
||||
OrtMemType mem_type = p_kernel_def->InputMemoryType(arg_idx);
|
||||
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, mem_type)->Info());
|
||||
set_node_arg_has_explicit_consumer.insert(index);
|
||||
} else { // implicit input
|
||||
// Only process an implicit input:
|
||||
// 1) Within a subgraph
|
||||
// 2) If there is no explicit consumer at this graph level
|
||||
// If there is an explicit consumer, the location MUST be where it is consumed
|
||||
// and not where it is located in the outer scope.
|
||||
// It is okay if we process a node consuming this arg as an implicit input
|
||||
// ahead of a node that is an explicit consumer, because we will just reset
|
||||
// this location in the 'if' branch above.
|
||||
|
||||
if (is_subgraph && set_node_arg_has_explicit_consumer.count(index) == 0) {
|
||||
auto iter = outer_scope_node_arg_to_location_map_.find(name);
|
||||
bool found_in_outer_scope_location_map = (iter != outer_scope_node_arg_to_location_map_.end());
|
||||
|
||||
if (!is_graph_input) {
|
||||
// Failing this enforce for an implicit subgraph input points to an internal error somewhere.
|
||||
// For certain older opsets (Scan-8), we may not have added explicit subgraph inputs
|
||||
// to the outer scope location map. See explanation in IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs()
|
||||
// called in FinalizeSessionStateImpl() in SessionState.
|
||||
ORT_ENFORCE(found_in_outer_scope_location_map,
|
||||
"There is no location for this node arg in the outer scope location map");
|
||||
}
|
||||
|
||||
if (found_in_outer_scope_location_map) {
|
||||
plan_.SetLocation(static_cast<size_t>(index), iter->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1062,6 +1099,7 @@ Status SequentialPlanner::CreatePlan(
|
|||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
std::unique_ptr<SequentialExecutionPlan>& plan) {
|
||||
|
|
@ -1069,7 +1107,8 @@ Status SequentialPlanner::CreatePlan(
|
|||
plan = std::make_unique<SequentialExecutionPlan>();
|
||||
|
||||
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
|
||||
kernel_create_info_map, ort_value_name_idx_map, context, *plan);
|
||||
kernel_create_info_map, outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map, context, *plan);
|
||||
|
||||
return planner.CreatePlan();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ class SequentialPlanner {
|
|||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
std::unique_ptr<SequentialExecutionPlan>& plan);
|
||||
|
|
|
|||
|
|
@ -1101,13 +1101,90 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
|
|||
remove_initializers, constant_initializers_use_count);
|
||||
}
|
||||
|
||||
static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const OrtValueName& name,
|
||||
/*out*/ OrtValueIndex& value) {
|
||||
return ort_value_name_idx_map.GetIdx(name, value);
|
||||
}
|
||||
|
||||
static bool IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs(const Node& node) {
|
||||
const auto& op_type = node.OpType();
|
||||
int since_version = node.SinceVersion();
|
||||
|
||||
// TODO: Re-visit this method if more subgraph ops get accepted into ONNX
|
||||
|
||||
// At the time of writing, there are only 3 ops in ONNX that have subgraphs
|
||||
// 1) If
|
||||
// 2) Loop
|
||||
// 3) Scan
|
||||
|
||||
// `If` - The op doesn't have explicit subgraph inputs (return false)
|
||||
// `Loop`- In all opset versions of Loop (at the time of writing) the node inputs
|
||||
// have a one-to-one mapping between them and the explicit subgraph inputs
|
||||
// of the subgraph held in the Loop (return true)
|
||||
// `Scan` - Except opset 8 version of Scan (at the time of writing), all other
|
||||
// versions have the same one-to-one mapping as Loop (return true for opset > 8)
|
||||
|
||||
return (op_type == "Loop" || (op_type == "Scan" && since_version >= 9));
|
||||
}
|
||||
|
||||
// The following method accumulates the locations of all inputs (implicit and explicit)
|
||||
// to a control flow node at the current graph level. This information will be used in
|
||||
// the allocation planner while determining the location of such inputs in the subgraph.
|
||||
// This method will not be called for the main graph (there is no concept of "outer scope" for the main graph).
|
||||
static Status OuterScopeNodeArgLocationAccumulator(const SequentialExecutionPlan& plan,
|
||||
const OrtValueNameIdxMap& ort_value_name_to_idx_map,
|
||||
const Node& parent_node,
|
||||
const GraphViewer& subgraph,
|
||||
/*out*/ std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map) {
|
||||
// Process implicit inputs to the node
|
||||
auto process_implicit_input = [&plan, &ort_value_name_to_idx_map,
|
||||
&outer_scope_arg_to_location_map](const NodeArg& input, size_t /*arg_idx*/) {
|
||||
const auto& name = input.Name();
|
||||
OrtValueIndex index = -1;
|
||||
ORT_RETURN_IF_ERROR(Index(ort_value_name_to_idx_map, name, index));
|
||||
outer_scope_arg_to_location_map.insert({name, plan.GetLocation(index)});
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(parent_node.ImplicitInputDefs(), process_implicit_input));
|
||||
|
||||
// Process explicit inputs to the node
|
||||
// (they are passed through as explicit subgraph inputs and hence requires a re-mapping of names
|
||||
// to their corresponding names in the inner nested subgraph(s) held by the node)
|
||||
const auto& subgraph_inputs = subgraph.GetInputs();
|
||||
|
||||
auto process_input = [&plan, &ort_value_name_to_idx_map, &outer_scope_arg_to_location_map,
|
||||
&subgraph_inputs](const NodeArg& input, size_t arg_idx) {
|
||||
const auto& name = input.Name();
|
||||
OrtValueIndex index = -1;
|
||||
ORT_RETURN_IF_ERROR(Index(ort_value_name_to_idx_map, name, index));
|
||||
|
||||
// Store the location of the outer scope value in the map using the subgraph input as the key
|
||||
// as that will be the referenced name in the subgraph (i.e.) re-mapping of names is required
|
||||
outer_scope_arg_to_location_map.insert({subgraph_inputs[arg_idx]->Name(), plan.GetLocation(index)});
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
if (IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs(parent_node)) {
|
||||
return Node::ForEachWithIndex(parent_node.InputDefs(), process_input);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers,
|
||||
std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
|
||||
CreateGraphInfo();
|
||||
std::unordered_map<std::string, size_t>& constant_initializers_use_count,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
|
||||
bool graph_info_already_created) {
|
||||
if (!graph_info_already_created) {
|
||||
CreateGraphInfo();
|
||||
}
|
||||
|
||||
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
|
||||
std::vector<const NodeArg*> valid_outer_scope_node_args;
|
||||
|
|
@ -1127,6 +1204,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order, session_options.enable_mem_reuse);
|
||||
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
|
||||
execution_providers_, kernel_create_info_map_,
|
||||
outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map_, context, p_seq_exec_plan_));
|
||||
//Record the allocation plan
|
||||
|
||||
|
|
@ -1218,8 +1296,19 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
SessionState& subgraph_session_state = *entry->second;
|
||||
|
||||
// recurse
|
||||
|
||||
// We need to create graph info for the subgraphs because information accumulated there
|
||||
// is used in OuterScopeNodeArgLocationAccumulator()
|
||||
subgraph_session_state.CreateGraphInfo();
|
||||
|
||||
std::unordered_map<OrtValueName, OrtMemoryInfo> subgraph_outer_scope_node_arg_to_location_map;
|
||||
ORT_RETURN_IF_ERROR(OuterScopeNodeArgLocationAccumulator(*p_seq_exec_plan_, GetOrtValueNameIdxMap(),
|
||||
node,
|
||||
subgraph_session_state.GetGraphViewer(),
|
||||
subgraph_outer_scope_node_arg_to_location_map));
|
||||
ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl(
|
||||
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers, constant_initializers_use_count));
|
||||
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers,
|
||||
constant_initializers_use_count, subgraph_outer_scope_node_arg_to_location_map, true));
|
||||
|
||||
// setup all the info for handling the feeds and fetches used in subgraph execution
|
||||
auto* p_op_kernel = GetMutableKernel(node.Index());
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ class SessionState {
|
|||
use_deterministic_compute_(use_deterministic_compute),
|
||||
enable_mem_reuse_(enable_mem_reuse),
|
||||
prepacked_weights_container_(prepacked_weights_container) {
|
||||
|
||||
SetupAllocators();
|
||||
}
|
||||
|
||||
|
|
@ -268,7 +267,7 @@ class SessionState {
|
|||
const KernelCreateInfo& GetNodeKernelCreateInfo(NodeIndex node_index) const;
|
||||
|
||||
/// Return SessionState for the given Node index and attribute name if found.
|
||||
const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const;
|
||||
const SessionState* GetSubgraphSessionState(NodeIndex index, const std::string& attribute_name) const;
|
||||
|
||||
concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; }
|
||||
concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; }
|
||||
|
|
@ -368,7 +367,9 @@ class SessionState {
|
|||
_In_opt_ const Node* parent_node,
|
||||
const SessionOptions& session_options,
|
||||
bool remove_initializers,
|
||||
std::unordered_map<std::string, size_t>& constant_initializers_use_count);
|
||||
std::unordered_map<std::string, size_t>& constant_initializers_use_count,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map = {},
|
||||
bool graph_info_already_created = false);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
Status GeneratePatternGroupCache(
|
||||
|
|
|
|||
|
|
@ -321,13 +321,11 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer&
|
|||
// implicit inputs to a node could come directly from a feed, so we need to make sure they have an entry too
|
||||
const auto& node_implicit_inputs = node.ImplicitInputDefs();
|
||||
if (!node_implicit_inputs.empty()) {
|
||||
// nested subgraph. for now map them to this node (which will be CPU based as all the control flow nodes
|
||||
// are currently CPU based and they're the only ones that have implicit inputs) as the inputs will be passed as a
|
||||
// feed when executing the subgraph and need to be in the mapping.
|
||||
// in the future we want to recurse and find where the implicit input is actually used to try and avoid a
|
||||
// copy to/from CPU to go through the control flow nodes where possible/applicable.
|
||||
// the processing for the subgraph where the implicit input is consumed will do the real check on whether any
|
||||
// copy to a different device is required
|
||||
// In nested subgraphs, the location of the implicit input(s) is the location it
|
||||
// is consumed in the subgraph if there is an explicit consumer.
|
||||
// If the only consumer(s) are implicit consumers (i.e.) other control flow nodes, its
|
||||
// location is the location of the value in the enclosing outer scope.
|
||||
// All this is setup in the planner, we just use the location from the plan here.
|
||||
for (const auto& input_def : node_implicit_inputs) {
|
||||
int arg_index;
|
||||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index));
|
||||
|
|
|
|||
|
|
@ -12,15 +12,23 @@
|
|||
#include "core/framework/op_kernel.h"
|
||||
#include "test/framework/model_builder_utils.h"
|
||||
#include "core/framework/allocation_planner.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/util/thread_utils.h"
|
||||
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
|
||||
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
|
||||
// GCC 4.x.
|
||||
// (This static var is referenced in `PassThroughExplicitAndImplicitSubgraphInputs` test)
|
||||
const OrtDevice::DeviceType OrtDevice::GPU;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
|
|
@ -154,9 +162,9 @@ class PlannerTest : public ::testing::Test {
|
|||
// some standard components used to build test-cases:
|
||||
Type float_type_;
|
||||
|
||||
std::unique_ptr<::onnxruntime::KernelDef> std_kernel_; // a unary kernel with no-aliasing and no-in-place
|
||||
std::unique_ptr<::onnxruntime::KernelDef> in_place_kernel_; // a unary kernel with in-place
|
||||
std::unique_ptr<::onnxruntime::KernelDef> external_outputs_kernel_; // an unary kernel with external outputs
|
||||
std::unique_ptr<::onnxruntime::KernelDef> std_kernel_; // a unary kernel with no-aliasing and no-in-place
|
||||
std::unique_ptr<::onnxruntime::KernelDef> in_place_kernel_; // a unary kernel with in-place
|
||||
std::unique_ptr<::onnxruntime::KernelDef> external_outputs_kernel_; // an unary kernel with external outputs
|
||||
|
||||
std::unordered_map<std::string, onnxruntime::NodeArg*> name_to_arg_;
|
||||
std::vector<std::unique_ptr<UnaryNode>> nodes_;
|
||||
|
|
@ -270,7 +278,7 @@ class PlannerTest : public ::testing::Test {
|
|||
SequentialPlannerTestContext test_context(&shape_map_);
|
||||
|
||||
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_,
|
||||
kernel_create_info_map, state_->GetOrtValueNameIdxMap(), test_context,
|
||||
kernel_create_info_map, {}, state_->GetOrtValueNameIdxMap(), test_context,
|
||||
plan_);
|
||||
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
|
@ -415,9 +423,9 @@ TEST_F(PlannerTest, ExternalOutputsTest) {
|
|||
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");
|
||||
|
||||
// graph structure:
|
||||
AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary
|
||||
AddNormalNode(X2, X3); // normal operator; X3: temporary
|
||||
AddNormalNode(X3, X4); // normal operator; X4: output
|
||||
AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary
|
||||
AddNormalNode(X2, X3); // normal operator; X3: temporary
|
||||
AddNormalNode(X3, X4); // normal operator; X4: output
|
||||
|
||||
// simulate shape-inference results:
|
||||
Shape shape1{"M", "N"};
|
||||
|
|
@ -505,5 +513,223 @@ TEST_F(PlannerTest, PlanOutputTest) {
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
|
||||
// Types
|
||||
TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param");
|
||||
|
||||
TypeProto int64_scalar;
|
||||
int64_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
|
||||
int64_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
TypeProto bool_scalar;
|
||||
bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
|
||||
bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
// The model has a main graph and 2 levels of nested subgraphs
|
||||
// Main graph: 2 Abs nodes + one Loop node
|
||||
// First level (Loop) subgraph: Identity (condition pass-through) + If node
|
||||
// Second level subgraph(s): Then and Else branches: Both have an Add node
|
||||
// The Add node adds 2 values:
|
||||
// One value from the main graph ("abs_data_0_out") that is "implicitly"
|
||||
// consumed by the Loop node and "passed through" to the If subgraphs.
|
||||
// Another value from the main graph ("abs_data_1_out") that is "explicitly"
|
||||
// consumed by the Loop node as a loop carried dependency and its name in
|
||||
// the scope of the Loop node is "loop_state_var".
|
||||
|
||||
// In the Loop subgraph, there are no explicit consumers of "abs_data_0_out"
|
||||
// and "loop_state_var", there is only one implicit consumer - "If".
|
||||
// We want to ensure that since there are no explicit consumers, the planned locations
|
||||
// for these values in this subgraph are the same locations as their corresponding
|
||||
// values in the outer scope, thus deferring any copies (if required) till the actual
|
||||
// subgraph(s) they are explicitly consumed in.
|
||||
auto create_model = [&float_tensor, &int64_scalar, &bool_scalar]() -> Model {
|
||||
auto create_if_subgraph = [&float_tensor](bool is_then) -> GraphProto {
|
||||
Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
auto& outer_scope_0 = graph.GetOrCreateNodeArg("loop_state_var", &float_tensor);
|
||||
graph.AddOuterScopeNodeArg("loop_state_var");
|
||||
|
||||
auto& outer_scope_1 = graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor);
|
||||
graph.AddOuterScopeNodeArg("abs_data_0_out");
|
||||
|
||||
auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &float_tensor);
|
||||
graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return graph.ToGraphProto();
|
||||
};
|
||||
|
||||
auto create_loop_subgraph = [&create_if_subgraph, &float_tensor, &int64_scalar, &bool_scalar]() -> GraphProto {
|
||||
Model model("loop_subgraph", true, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
|
||||
/* Inputs: iter_num, cond_in, loop carried state variables.
|
||||
|
||||
iter_num_in cond_in [loop_state_var]
|
||||
(unused) | |
|
||||
[Identity] [If]
|
||||
| |
|
||||
cond_out loop_state_var_out
|
||||
*/
|
||||
|
||||
// graph inputs
|
||||
auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
|
||||
auto& cond_in = graph.GetOrCreateNodeArg("cond_in", &bool_scalar);
|
||||
auto& loop_state_var = graph.GetOrCreateNodeArg("loop_state_var", &float_tensor);
|
||||
|
||||
// graph outputs
|
||||
auto& cond_out = graph.GetOrCreateNodeArg("cond_out", &bool_scalar);
|
||||
auto& loop_state_var_out = graph.GetOrCreateNodeArg("loop_state_var_out", &float_tensor);
|
||||
|
||||
// outer scope args
|
||||
ORT_IGNORE_RETURN_VALUE(graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor));
|
||||
graph.AddOuterScopeNodeArg("abs_data_0_out");
|
||||
|
||||
// cond_in -> cond_out
|
||||
{
|
||||
inputs = {&cond_in};
|
||||
outputs = {&cond_out};
|
||||
|
||||
graph.AddNode("cond_in_identity", "Identity", "Forward cond_in to cond_out", inputs, outputs);
|
||||
}
|
||||
|
||||
// loop_state_var -> If(cond_in) -> loop_state_var_out
|
||||
{
|
||||
inputs = {&cond_in};
|
||||
outputs = {&loop_state_var_out};
|
||||
|
||||
auto& node = graph.AddNode("loop_var_out", "If", "If with loop_state_var as implicit_input", inputs, outputs);
|
||||
node.AddAttribute("then_branch", create_if_subgraph(true));
|
||||
node.AddAttribute("else_branch", create_if_subgraph(false));
|
||||
}
|
||||
|
||||
graph.SetInputs({&iter_num_in, &cond_in, &loop_state_var});
|
||||
graph.SetOutputs({&cond_out, &loop_state_var_out});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return graph.ToGraphProto();
|
||||
};
|
||||
|
||||
onnxruntime::Model model("main_graph", false, ModelMetaData(),
|
||||
PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& main_graph = model.MainGraph();
|
||||
|
||||
// Abs-0
|
||||
auto& abs_data_0_in = main_graph.GetOrCreateNodeArg("abs_data_0_in", &float_tensor);
|
||||
auto& abs_data_0_out = main_graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor);
|
||||
std::vector<onnxruntime::NodeArg*> abs_0_inputs = {&abs_data_0_in};
|
||||
std::vector<onnxruntime::NodeArg*> abs_0_outputs = {&abs_data_0_out};
|
||||
main_graph.AddNode("abs_0", "Abs", "node abs", abs_0_inputs, abs_0_outputs);
|
||||
|
||||
// Abs-1
|
||||
auto& abs_data_1_in = main_graph.GetOrCreateNodeArg("abs_data_1_in", &float_tensor);
|
||||
auto& abs_data_1_out = main_graph.GetOrCreateNodeArg("abs_data_1_out", &float_tensor);
|
||||
std::vector<onnxruntime::NodeArg*> abs_1_inputs = {&abs_data_1_in};
|
||||
std::vector<onnxruntime::NodeArg*> abs_1_outputs = {&abs_data_1_out};
|
||||
main_graph.AddNode("abs_1", "Abs", "node abs", abs_1_inputs, abs_1_outputs);
|
||||
|
||||
// Loop
|
||||
auto& iter_num_in = main_graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
|
||||
auto& cond_in = main_graph.GetOrCreateNodeArg("cond_in", &bool_scalar);
|
||||
auto& loop_state_out_var = main_graph.GetOrCreateNodeArg("loop_state_out_var", &float_tensor);
|
||||
|
||||
auto& loop_node = main_graph.AddNode("loop", "Loop", "Loop node",
|
||||
{&iter_num_in, &cond_in, &abs_data_1_out},
|
||||
{&loop_state_out_var});
|
||||
loop_node.AddAttribute("body", create_loop_subgraph());
|
||||
|
||||
main_graph.SetInputs({&abs_data_0_in, &abs_data_1_in, &iter_num_in, &cond_in});
|
||||
main_graph.SetOutputs({&loop_state_out_var});
|
||||
|
||||
auto status = main_graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return model;
|
||||
};
|
||||
|
||||
// Create and load session
|
||||
SessionOptions so;
|
||||
InferenceSession sess{so, GetEnvironment()};
|
||||
|
||||
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
std::string s1;
|
||||
const bool rc = create_model().ToProto().SerializeToString(&s1);
|
||||
EXPECT_EQ(rc, true);
|
||||
std::stringstream sstr(s1);
|
||||
|
||||
status = sess.Load(sstr);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
status = sess.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
// Check planned locations of values in the main graph that are implicit subgraph inputs
|
||||
// and explicit subgraph inputs to the Loop node
|
||||
|
||||
// Main graph (L0 graph)
|
||||
const auto& main_graph_session_state = sess.GetSessionState();
|
||||
|
||||
{
|
||||
const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap();
|
||||
const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan();
|
||||
|
||||
OrtValueIndex abs_data_0_out_index;
|
||||
main_graph_ort_value_index_map.GetIdx("abs_data_0_out", abs_data_0_out_index);
|
||||
|
||||
OrtValueIndex abs_data_1_out_index;
|
||||
main_graph_ort_value_index_map.GetIdx("abs_data_1_out", abs_data_1_out_index);
|
||||
|
||||
EXPECT_EQ(main_graph_plan->allocation_plan[abs_data_0_out_index].location.device.Type(), OrtDevice::GPU);
|
||||
EXPECT_EQ(main_graph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
|
||||
}
|
||||
|
||||
// First subgraph (Loop) (L1 graph)
|
||||
// There are 3 nodes in the main level- Only one of them has a subgraph (Loop).
|
||||
// Find that.
|
||||
const SessionState* find_first_subgraph_session_state = nullptr;
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
find_first_subgraph_session_state = main_graph_session_state.GetSubgraphSessionState(i, "body");
|
||||
if (find_first_subgraph_session_state) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& first_subgraph_session_state = *find_first_subgraph_session_state;
|
||||
|
||||
{
|
||||
const auto& first_subgraph_ort_value_index_map = first_subgraph_session_state.GetOrtValueNameIdxMap();
|
||||
const auto* first_subgraph_plan = first_subgraph_session_state.GetExecutionPlan();
|
||||
|
||||
OrtValueIndex abs_data_0_out_index;
|
||||
first_subgraph_ort_value_index_map.GetIdx("abs_data_0_out", abs_data_0_out_index);
|
||||
|
||||
// "abs_data_1_out" is "loop_state_var" in this scope as it was consumed as an explicit subgraph input
|
||||
// to Loop's body subgraph
|
||||
OrtValueIndex abs_data_1_out_index;
|
||||
first_subgraph_ort_value_index_map.GetIdx("loop_state_var", abs_data_1_out_index);
|
||||
|
||||
// There are no explicit consumers of "abs_data_0_out" and "loop_state_var (abs_data_1_out)" in this scope.
|
||||
// There is only one implicit consumer "If". Hence, check that we are preserving the locations of these values
|
||||
// from the outer scope, thus deferring any copies till the actual nested subgraph these values are used in.
|
||||
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_0_out_index].location.device.Type(), OrtDevice::GPU);
|
||||
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue