diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index cd33a96fae..fbbfdebcca 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -121,6 +121,7 @@ class PlannerImpl { PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer, const std::vector& outer_scope_node_args, const ExecutionProviders& providers, const std::unordered_map>& kernel_create_info_map, + const std::unordered_map& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) : context_(context), @@ -130,6 +131,7 @@ class PlannerImpl { outer_scope_node_args_(outer_scope_node_args), execution_providers_(providers), kernel_create_info_map_(kernel_create_info_map), + outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map), ort_value_name_idx_map_(ort_value_name_idx_map) {} Status CreatePlan(); @@ -144,6 +146,9 @@ class PlannerImpl { const ExecutionProviders& execution_providers_; const std::unordered_map>& kernel_create_info_map_; + + const std::unordered_map& outer_scope_node_arg_to_location_map_; + const OrtValueNameIdxMap& ort_value_name_idx_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -261,9 +266,9 @@ class PlannerImpl { // Inputs of Yields are essentially the outputs for FW partial subgraph // Thses tensors will be pass back to pytorch, thus cannot share the buffer with other tensors - // Unhandled corner case: + // Unhandled corner case: // If FW output tensor is consumed by BW graph, and pytorch performs an inplace operation on th returned tensor, - // we will run into a buffer corruption problem. + // we will run into a buffer corruption problem. // One potential fix is returning a copy of output tensor, if it has downstream dependency auto p_next_node = node.OutputNodesBegin(); if (p_next_node != node.OutputNodesEnd() && p_next_node->OpType() == "YieldOp") { @@ -483,6 +488,8 @@ class PlannerImpl { UseCount(initializer_name)++; } + std::unordered_set set_node_arg_has_explicit_consumer; + for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) { auto pnode = graph_viewer_.GetNode(step.node_index); if (pnode == nullptr) { @@ -507,26 +514,56 @@ class PlannerImpl { // increment UseCount and add location information if applicable for the provided input def auto process_input = [&graph_inputs, &exec_provider, &p_kernel_def, &is_implicit_input, + &set_node_arg_has_explicit_consumer, this](const NodeArg& input, size_t arg_idx) { const auto& name = input.Name(); UseCount(name)++; + bool is_graph_input = (graph_inputs.find(name) != graph_inputs.cend()); + bool is_outer_scope_arg = std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(), + [&name](const NodeArg* value) { + return value && value->Name() == name; + }) != outer_scope_node_args_.cend(); + bool is_subgraph = (parent_node_ != nullptr); + // If it's a graph input or outer scope node arg, set its plan. // NOTE: Copy nodes should have already been added if a graph input is fed as input // to nodes assigned to different providers. - if (graph_inputs.find(name) != graph_inputs.cend() || - std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(), - [&name](const NodeArg* value) { - return value && value->Name() == name; - }) != outer_scope_node_args_.cend()) { + + if (is_graph_input || is_outer_scope_arg) { OrtValueIndex index = Index(name); - // implicit inputs do not have an entry in the kernel def, so do nothing to them here, leaving the control - // flow op (Loop, Scan, If) to do the necessary copy if the input crosses different provider. - // matching logic is used in TransformerMemcpyImpl::ProcessDefs if (!is_implicit_input) { OrtMemType mem_type = p_kernel_def->InputMemoryType(arg_idx); plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(0, mem_type)->Info()); + set_node_arg_has_explicit_consumer.insert(index); + } else { // implicit input + // Only process an implicit input: + // 1) Within a subgraph + // 2) If there is no explicit consumer at this graph level + // If there is an explicit consumer, the location MUST be where it is consumed + // and not where it is located in the outer scope. + // It is okay if we process a node consuming this arg as an implicit input + // ahead of a node that is an explicit consumer, because we will just reset + // this location in the 'if' branch above. + + if (is_subgraph && set_node_arg_has_explicit_consumer.count(index) == 0) { + auto iter = outer_scope_node_arg_to_location_map_.find(name); + bool found_in_outer_scope_location_map = (iter != outer_scope_node_arg_to_location_map_.end()); + + if (!is_graph_input) { + // Failing this enforce for an implicit subgraph input points to an internal error somewhere. + // For certain older opsets (Scan-8), we may not have added explicit subgraph inputs + // to the outer scope location map. See explanation in IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs() + // called in FinalizeSessionStateImpl() in SessionState. + ORT_ENFORCE(found_in_outer_scope_location_map, + "There is no location for this node arg in the outer scope location map"); + } + + if (found_in_outer_scope_location_map) { + plan_.SetLocation(static_cast(index), iter->second); + } + } } } @@ -1062,6 +1099,7 @@ Status SequentialPlanner::CreatePlan( const std::vector& outer_scope_node_args, const ExecutionProviders& providers, const std::unordered_map>& kernel_create_info_map, + const std::unordered_map& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, std::unique_ptr& plan) { @@ -1069,7 +1107,8 @@ Status SequentialPlanner::CreatePlan( plan = std::make_unique(); PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, - kernel_create_info_map, ort_value_name_idx_map, context, *plan); + kernel_create_info_map, outer_scope_node_arg_to_location_map, + ort_value_name_idx_map, context, *plan); return planner.CreatePlan(); } diff --git a/onnxruntime/core/framework/allocation_planner.h b/onnxruntime/core/framework/allocation_planner.h index 3bafc05a9d..0bce0390ea 100644 --- a/onnxruntime/core/framework/allocation_planner.h +++ b/onnxruntime/core/framework/allocation_planner.h @@ -66,6 +66,7 @@ class SequentialPlanner { const std::vector& outer_scope_node_args, const ExecutionProviders& providers, const std::unordered_map>& kernel_create_info_map, + const std::unordered_map& outer_scope_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, std::unique_ptr& plan); diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 611564f370..b1618a2d70 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1101,13 +1101,90 @@ Status SessionState::FinalizeSessionState(const std::basic_string 8) + + return (op_type == "Loop" || (op_type == "Scan" && since_version >= 9)); +} + +// The following method accumulates the locations of all inputs (implicit and explicit) +// to a control flow node at the current graph level. This information will be used in +// the allocation planner while determining the location of such inputs in the subgraph. +// This method will not be called for the main graph (there is no concept of "outer scope" for the main graph). +static Status OuterScopeNodeArgLocationAccumulator(const SequentialExecutionPlan& plan, + const OrtValueNameIdxMap& ort_value_name_to_idx_map, + const Node& parent_node, + const GraphViewer& subgraph, + /*out*/ std::unordered_map& outer_scope_arg_to_location_map) { + // Process implicit inputs to the node + auto process_implicit_input = [&plan, &ort_value_name_to_idx_map, + &outer_scope_arg_to_location_map](const NodeArg& input, size_t /*arg_idx*/) { + const auto& name = input.Name(); + OrtValueIndex index = -1; + ORT_RETURN_IF_ERROR(Index(ort_value_name_to_idx_map, name, index)); + outer_scope_arg_to_location_map.insert({name, plan.GetLocation(index)}); + return Status::OK(); + }; + + ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(parent_node.ImplicitInputDefs(), process_implicit_input)); + + // Process explicit inputs to the node + // (they are passed through as explicit subgraph inputs and hence requires a re-mapping of names + // to their corresponding names in the inner nested subgraph(s) held by the node) + const auto& subgraph_inputs = subgraph.GetInputs(); + + auto process_input = [&plan, &ort_value_name_to_idx_map, &outer_scope_arg_to_location_map, + &subgraph_inputs](const NodeArg& input, size_t arg_idx) { + const auto& name = input.Name(); + OrtValueIndex index = -1; + ORT_RETURN_IF_ERROR(Index(ort_value_name_to_idx_map, name, index)); + + // Store the location of the outer scope value in the map using the subgraph input as the key + // as that will be the referenced name in the subgraph (i.e.) re-mapping of names is required + outer_scope_arg_to_location_map.insert({subgraph_inputs[arg_idx]->Name(), plan.GetLocation(index)}); + + return Status::OK(); + }; + + if (IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs(parent_node)) { + return Node::ForEachWithIndex(parent_node.InputDefs(), process_input); + } + + return Status::OK(); +} + Status SessionState::FinalizeSessionStateImpl(const std::basic_string& graph_location, KernelRegistryManager& kernel_registry_manager, _In_opt_ const Node* parent_node, const SessionOptions& session_options, bool remove_initializers, - std::unordered_map& constant_initializers_use_count) { - CreateGraphInfo(); + std::unordered_map& constant_initializers_use_count, + const std::unordered_map& outer_scope_node_arg_to_location_map, + bool graph_info_already_created) { + if (!graph_info_already_created) { + CreateGraphInfo(); + } // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. std::vector valid_outer_scope_node_args; @@ -1127,6 +1204,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringsecond; // recurse + + // We need to create graph info for the subgraphs because information accumulated there + // is used in OuterScopeNodeArgLocationAccumulator() + subgraph_session_state.CreateGraphInfo(); + + std::unordered_map subgraph_outer_scope_node_arg_to_location_map; + ORT_RETURN_IF_ERROR(OuterScopeNodeArgLocationAccumulator(*p_seq_exec_plan_, GetOrtValueNameIdxMap(), + node, + subgraph_session_state.GetGraphViewer(), + subgraph_outer_scope_node_arg_to_location_map)); ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl( - graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers, constant_initializers_use_count)); + graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers, + constant_initializers_use_count, subgraph_outer_scope_node_arg_to_location_map, true)); // setup all the info for handling the feeds and fetches used in subgraph execution auto* p_op_kernel = GetMutableKernel(node.Index()); diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index d1aa5a9d3a..c664b5bfe4 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -101,7 +101,6 @@ class SessionState { use_deterministic_compute_(use_deterministic_compute), enable_mem_reuse_(enable_mem_reuse), prepacked_weights_container_(prepacked_weights_container) { - SetupAllocators(); } @@ -268,7 +267,7 @@ class SessionState { const KernelCreateInfo& GetNodeKernelCreateInfo(NodeIndex node_index) const; /// Return SessionState for the given Node index and attribute name if found. - const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const; + const SessionState* GetSubgraphSessionState(NodeIndex index, const std::string& attribute_name) const; concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; } concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; } @@ -368,7 +367,9 @@ class SessionState { _In_opt_ const Node* parent_node, const SessionOptions& session_options, bool remove_initializers, - std::unordered_map& constant_initializers_use_count); + std::unordered_map& constant_initializers_use_count, + const std::unordered_map& outer_scope_node_arg_to_location_map = {}, + bool graph_info_already_created = false); #ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 6ec20d0981..bf35b82beb 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -321,13 +321,11 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& // implicit inputs to a node could come directly from a feed, so we need to make sure they have an entry too const auto& node_implicit_inputs = node.ImplicitInputDefs(); if (!node_implicit_inputs.empty()) { - // nested subgraph. for now map them to this node (which will be CPU based as all the control flow nodes - // are currently CPU based and they're the only ones that have implicit inputs) as the inputs will be passed as a - // feed when executing the subgraph and need to be in the mapping. - // in the future we want to recurse and find where the implicit input is actually used to try and avoid a - // copy to/from CPU to go through the control flow nodes where possible/applicable. - // the processing for the subgraph where the implicit input is consumed will do the real check on whether any - // copy to a different device is required + // In nested subgraphs, the location of the implicit input(s) is the location it + // is consumed in the subgraph if there is an explicit consumer. + // If the only consumer(s) are implicit consumers (i.e.) other control flow nodes, its + // location is the location of the value in the enclosing outer scope. + // All this is setup in the planner, we just use the location from the plan here. for (const auto& input_def : node_implicit_inputs) { int arg_index; ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index)); diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 941661390f..6ab95736e8 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -12,15 +12,23 @@ #include "core/framework/op_kernel.h" #include "test/framework/model_builder_utils.h" #include "core/framework/allocation_planner.h" +#include "core/session/inference_session.h" #include "core/graph/model.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/util/thread_utils.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" using namespace ONNX_NAMESPACE; +// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct, +// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses +// GCC 4.x. +// (This static var is referenced in `PassThroughExplicitAndImplicitSubgraphInputs` test) +const OrtDevice::DeviceType OrtDevice::GPU; + namespace onnxruntime { namespace test { @@ -154,9 +162,9 @@ class PlannerTest : public ::testing::Test { // some standard components used to build test-cases: Type float_type_; - std::unique_ptr<::onnxruntime::KernelDef> std_kernel_; // a unary kernel with no-aliasing and no-in-place - std::unique_ptr<::onnxruntime::KernelDef> in_place_kernel_; // a unary kernel with in-place - std::unique_ptr<::onnxruntime::KernelDef> external_outputs_kernel_; // an unary kernel with external outputs + std::unique_ptr<::onnxruntime::KernelDef> std_kernel_; // a unary kernel with no-aliasing and no-in-place + std::unique_ptr<::onnxruntime::KernelDef> in_place_kernel_; // a unary kernel with in-place + std::unique_ptr<::onnxruntime::KernelDef> external_outputs_kernel_; // an unary kernel with external outputs std::unordered_map name_to_arg_; std::vector> nodes_; @@ -270,7 +278,7 @@ class PlannerTest : public ::testing::Test { SequentialPlannerTestContext test_context(&shape_map_); status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_, - kernel_create_info_map, state_->GetOrtValueNameIdxMap(), test_context, + kernel_create_info_map, {}, state_->GetOrtValueNameIdxMap(), test_context, plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -415,9 +423,9 @@ TEST_F(PlannerTest, ExternalOutputsTest) { std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"); // graph structure: - AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary - AddNormalNode(X2, X3); // normal operator; X3: temporary - AddNormalNode(X3, X4); // normal operator; X4: output + AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary + AddNormalNode(X2, X3); // normal operator; X3: temporary + AddNormalNode(X3, X4); // normal operator; X4: output // simulate shape-inference results: Shape shape1{"M", "N"}; @@ -505,5 +513,223 @@ TEST_F(PlannerTest, PlanOutputTest) { } } +#ifdef USE_CUDA +TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) { + // Types + TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param"); + + TypeProto int64_scalar; + int64_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + int64_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + TypeProto bool_scalar; + bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL); + bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + // The model has a main graph and 2 levels of nested subgraphs + // Main graph: 2 Abs nodes + one Loop node + // First level (Loop) subgraph: Identity (condition pass-through) + If node + // Second level subgraph(s): Then and Else branches: Both have an Add node + // The Add node adds 2 values: + // One value from the main graph ("abs_data_0_out") that is "implicitly" + // consumed by the Loop node and "passed through" to the If subgraphs. + // Another value from the main graph ("abs_data_1_out") that is "explicitly" + // consumed by the Loop node as a loop carried dependency and its name in + // the scope of the Loop node is "loop_state_var". + + // In the Loop subgraph, there are no explicit consumers of "abs_data_0_out" + // and "loop_state_var", there is only one implicit consumer - "If". + // We want to ensure that since there are no explicit consumers, the planned locations + // for these values in this subgraph are the same locations as their corresponding + // values in the outer scope, thus deferring any copies (if required) till the actual + // subgraph(s) they are explicitly consumed in. + auto create_model = [&float_tensor, &int64_scalar, &bool_scalar]() -> Model { + auto create_if_subgraph = [&float_tensor](bool is_then) -> GraphProto { + Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + auto& outer_scope_0 = graph.GetOrCreateNodeArg("loop_state_var", &float_tensor); + graph.AddOuterScopeNodeArg("loop_state_var"); + + auto& outer_scope_1 = graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor); + graph.AddOuterScopeNodeArg("abs_data_0_out"); + + auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &float_tensor); + graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out}); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + return graph.ToGraphProto(); + }; + + auto create_loop_subgraph = [&create_if_subgraph, &float_tensor, &int64_scalar, &bool_scalar]() -> GraphProto { + Model model("loop_subgraph", true, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + std::vector inputs; + std::vector outputs; + + /* Inputs: iter_num, cond_in, loop carried state variables. + + iter_num_in cond_in [loop_state_var] + (unused) | | + [Identity] [If] + | | + cond_out loop_state_var_out + */ + + // graph inputs + auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar); + auto& cond_in = graph.GetOrCreateNodeArg("cond_in", &bool_scalar); + auto& loop_state_var = graph.GetOrCreateNodeArg("loop_state_var", &float_tensor); + + // graph outputs + auto& cond_out = graph.GetOrCreateNodeArg("cond_out", &bool_scalar); + auto& loop_state_var_out = graph.GetOrCreateNodeArg("loop_state_var_out", &float_tensor); + + // outer scope args + ORT_IGNORE_RETURN_VALUE(graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor)); + graph.AddOuterScopeNodeArg("abs_data_0_out"); + + // cond_in -> cond_out + { + inputs = {&cond_in}; + outputs = {&cond_out}; + + graph.AddNode("cond_in_identity", "Identity", "Forward cond_in to cond_out", inputs, outputs); + } + + // loop_state_var -> If(cond_in) -> loop_state_var_out + { + inputs = {&cond_in}; + outputs = {&loop_state_var_out}; + + auto& node = graph.AddNode("loop_var_out", "If", "If with loop_state_var as implicit_input", inputs, outputs); + node.AddAttribute("then_branch", create_if_subgraph(true)); + node.AddAttribute("else_branch", create_if_subgraph(false)); + } + + graph.SetInputs({&iter_num_in, &cond_in, &loop_state_var}); + graph.SetOutputs({&cond_out, &loop_state_var_out}); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + return graph.ToGraphProto(); + }; + + onnxruntime::Model model("main_graph", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), + {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + auto& main_graph = model.MainGraph(); + + // Abs-0 + auto& abs_data_0_in = main_graph.GetOrCreateNodeArg("abs_data_0_in", &float_tensor); + auto& abs_data_0_out = main_graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor); + std::vector abs_0_inputs = {&abs_data_0_in}; + std::vector abs_0_outputs = {&abs_data_0_out}; + main_graph.AddNode("abs_0", "Abs", "node abs", abs_0_inputs, abs_0_outputs); + + // Abs-1 + auto& abs_data_1_in = main_graph.GetOrCreateNodeArg("abs_data_1_in", &float_tensor); + auto& abs_data_1_out = main_graph.GetOrCreateNodeArg("abs_data_1_out", &float_tensor); + std::vector abs_1_inputs = {&abs_data_1_in}; + std::vector abs_1_outputs = {&abs_data_1_out}; + main_graph.AddNode("abs_1", "Abs", "node abs", abs_1_inputs, abs_1_outputs); + + // Loop + auto& iter_num_in = main_graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar); + auto& cond_in = main_graph.GetOrCreateNodeArg("cond_in", &bool_scalar); + auto& loop_state_out_var = main_graph.GetOrCreateNodeArg("loop_state_out_var", &float_tensor); + + auto& loop_node = main_graph.AddNode("loop", "Loop", "Loop node", + {&iter_num_in, &cond_in, &abs_data_1_out}, + {&loop_state_out_var}); + loop_node.AddAttribute("body", create_loop_subgraph()); + + main_graph.SetInputs({&abs_data_0_in, &abs_data_1_in, &iter_num_in, &cond_in}); + main_graph.SetOutputs({&loop_state_out_var}); + + auto status = main_graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + return model; + }; + + // Create and load session + SessionOptions so; + InferenceSession sess{so, GetEnvironment()}; + + auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); + ASSERT_TRUE(status.IsOK()); + + std::string s1; + const bool rc = create_model().ToProto().SerializeToString(&s1); + EXPECT_EQ(rc, true); + std::stringstream sstr(s1); + + status = sess.Load(sstr); + ASSERT_TRUE(status.IsOK()); + + status = sess.Initialize(); + ASSERT_TRUE(status.IsOK()); + + // Check planned locations of values in the main graph that are implicit subgraph inputs + // and explicit subgraph inputs to the Loop node + + // Main graph (L0 graph) + const auto& main_graph_session_state = sess.GetSessionState(); + + { + const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap(); + const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan(); + + OrtValueIndex abs_data_0_out_index; + main_graph_ort_value_index_map.GetIdx("abs_data_0_out", abs_data_0_out_index); + + OrtValueIndex abs_data_1_out_index; + main_graph_ort_value_index_map.GetIdx("abs_data_1_out", abs_data_1_out_index); + + EXPECT_EQ(main_graph_plan->allocation_plan[abs_data_0_out_index].location.device.Type(), OrtDevice::GPU); + EXPECT_EQ(main_graph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU); + } + + // First subgraph (Loop) (L1 graph) + // There are 3 nodes in the main level- Only one of them has a subgraph (Loop). + // Find that. + const SessionState* find_first_subgraph_session_state = nullptr; + for (size_t i = 0; i < 3; ++i) { + find_first_subgraph_session_state = main_graph_session_state.GetSubgraphSessionState(i, "body"); + if (find_first_subgraph_session_state) { + break; + } + } + + const auto& first_subgraph_session_state = *find_first_subgraph_session_state; + + { + const auto& first_subgraph_ort_value_index_map = first_subgraph_session_state.GetOrtValueNameIdxMap(); + const auto* first_subgraph_plan = first_subgraph_session_state.GetExecutionPlan(); + + OrtValueIndex abs_data_0_out_index; + first_subgraph_ort_value_index_map.GetIdx("abs_data_0_out", abs_data_0_out_index); + + // "abs_data_1_out" is "loop_state_var" in this scope as it was consumed as an explicit subgraph input + // to Loop's body subgraph + OrtValueIndex abs_data_1_out_index; + first_subgraph_ort_value_index_map.GetIdx("loop_state_var", abs_data_1_out_index); + + // There are no explicit consumers of "abs_data_0_out" and "loop_state_var (abs_data_1_out)" in this scope. + // There is only one implicit consumer "If". Hence, check that we are preserving the locations of these values + // from the outer scope, thus deferring any copies till the actual nested subgraph these values are used in. + EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_0_out_index].location.device.Type(), OrtDevice::GPU); + EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU); + } +} +#endif + } // namespace test } // namespace onnxruntime