Avoid round-trip copies for "pass through" subgraph inputs (#8702)

This commit is contained in:
Hariharan Seshadri 2021-08-30 21:30:01 -07:00 committed by GitHub
parent 42ba0c5931
commit 7659148d9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 385 additions and 31 deletions

View file

@ -121,6 +121,7 @@ class PlannerImpl {
PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers,
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
: context_(context),
@ -130,6 +131,7 @@ class PlannerImpl {
outer_scope_node_args_(outer_scope_node_args),
execution_providers_(providers),
kernel_create_info_map_(kernel_create_info_map),
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
ort_value_name_idx_map_(ort_value_name_idx_map) {}
Status CreatePlan();
@ -144,6 +146,9 @@ class PlannerImpl {
const ExecutionProviders& execution_providers_;
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map_;
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map_;
const OrtValueNameIdxMap& ort_value_name_idx_map_;
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@ -261,9 +266,9 @@ class PlannerImpl {
// Inputs of Yields are essentially the outputs for FW partial subgraph
// Thses tensors will be pass back to pytorch, thus cannot share the buffer with other tensors
// Unhandled corner case:
// Unhandled corner case:
// If FW output tensor is consumed by BW graph, and pytorch performs an inplace operation on th returned tensor,
// we will run into a buffer corruption problem.
// we will run into a buffer corruption problem.
// One potential fix is returning a copy of output tensor, if it has downstream dependency
auto p_next_node = node.OutputNodesBegin();
if (p_next_node != node.OutputNodesEnd() && p_next_node->OpType() == "YieldOp") {
@ -483,6 +488,8 @@ class PlannerImpl {
UseCount(initializer_name)++;
}
std::unordered_set<OrtValueIndex> set_node_arg_has_explicit_consumer;
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
auto pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) {
@ -507,26 +514,56 @@ class PlannerImpl {
// increment UseCount and add location information if applicable for the provided input def
auto process_input = [&graph_inputs, &exec_provider, &p_kernel_def, &is_implicit_input,
&set_node_arg_has_explicit_consumer,
this](const NodeArg& input, size_t arg_idx) {
const auto& name = input.Name();
UseCount(name)++;
bool is_graph_input = (graph_inputs.find(name) != graph_inputs.cend());
bool is_outer_scope_arg = std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
[&name](const NodeArg* value) {
return value && value->Name() == name;
}) != outer_scope_node_args_.cend();
bool is_subgraph = (parent_node_ != nullptr);
// If it's a graph input or outer scope node arg, set its plan.
// NOTE: Copy nodes should have already been added if a graph input is fed as input
// to nodes assigned to different providers.
if (graph_inputs.find(name) != graph_inputs.cend() ||
std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
[&name](const NodeArg* value) {
return value && value->Name() == name;
}) != outer_scope_node_args_.cend()) {
if (is_graph_input || is_outer_scope_arg) {
OrtValueIndex index = Index(name);
// implicit inputs do not have an entry in the kernel def, so do nothing to them here, leaving the control
// flow op (Loop, Scan, If) to do the necessary copy if the input crosses different provider.
// matching logic is used in TransformerMemcpyImpl::ProcessDefs
if (!is_implicit_input) {
OrtMemType mem_type = p_kernel_def->InputMemoryType(arg_idx);
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, mem_type)->Info());
set_node_arg_has_explicit_consumer.insert(index);
} else { // implicit input
// Only process an implicit input:
// 1) Within a subgraph
// 2) If there is no explicit consumer at this graph level
// If there is an explicit consumer, the location MUST be where it is consumed
// and not where it is located in the outer scope.
// It is okay if we process a node consuming this arg as an implicit input
// ahead of a node that is an explicit consumer, because we will just reset
// this location in the 'if' branch above.
if (is_subgraph && set_node_arg_has_explicit_consumer.count(index) == 0) {
auto iter = outer_scope_node_arg_to_location_map_.find(name);
bool found_in_outer_scope_location_map = (iter != outer_scope_node_arg_to_location_map_.end());
if (!is_graph_input) {
// Failing this enforce for an implicit subgraph input points to an internal error somewhere.
// For certain older opsets (Scan-8), we may not have added explicit subgraph inputs
// to the outer scope location map. See explanation in IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs()
// called in FinalizeSessionStateImpl() in SessionState.
ORT_ENFORCE(found_in_outer_scope_location_map,
"There is no location for this node arg in the outer scope location map");
}
if (found_in_outer_scope_location_map) {
plan_.SetLocation(static_cast<size_t>(index), iter->second);
}
}
}
}
@ -1062,6 +1099,7 @@ Status SequentialPlanner::CreatePlan(
const std::vector<const NodeArg*>& outer_scope_node_args,
const ExecutionProviders& providers,
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context,
std::unique_ptr<SequentialExecutionPlan>& plan) {
@ -1069,7 +1107,8 @@ Status SequentialPlanner::CreatePlan(
plan = std::make_unique<SequentialExecutionPlan>();
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
kernel_create_info_map, ort_value_name_idx_map, context, *plan);
kernel_create_info_map, outer_scope_node_arg_to_location_map,
ort_value_name_idx_map, context, *plan);
return planner.CreatePlan();
}

View file

@ -66,6 +66,7 @@ class SequentialPlanner {
const std::vector<const NodeArg*>& outer_scope_node_args,
const ExecutionProviders& providers,
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context,
std::unique_ptr<SequentialExecutionPlan>& plan);

View file

@ -1101,13 +1101,90 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
remove_initializers, constant_initializers_use_count);
}
static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map,
const OrtValueName& name,
/*out*/ OrtValueIndex& value) {
return ort_value_name_idx_map.GetIdx(name, value);
}
static bool IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs(const Node& node) {
const auto& op_type = node.OpType();
int since_version = node.SinceVersion();
// TODO: Re-visit this method if more subgraph ops get accepted into ONNX
// At the time of writing, there are only 3 ops in ONNX that have subgraphs
// 1) If
// 2) Loop
// 3) Scan
// `If` - The op doesn't have explicit subgraph inputs (return false)
// `Loop`- In all opset versions of Loop (at the time of writing) the node inputs
// have a one-to-one mapping between them and the explicit subgraph inputs
// of the subgraph held in the Loop (return true)
// `Scan` - Except opset 8 version of Scan (at the time of writing), all other
// versions have the same one-to-one mapping as Loop (return true for opset > 8)
return (op_type == "Loop" || (op_type == "Scan" && since_version >= 9));
}
// The following method accumulates the locations of all inputs (implicit and explicit)
// to a control flow node at the current graph level. This information will be used in
// the allocation planner while determining the location of such inputs in the subgraph.
// This method will not be called for the main graph (there is no concept of "outer scope" for the main graph).
static Status OuterScopeNodeArgLocationAccumulator(const SequentialExecutionPlan& plan,
const OrtValueNameIdxMap& ort_value_name_to_idx_map,
const Node& parent_node,
const GraphViewer& subgraph,
/*out*/ std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map) {
// Process implicit inputs to the node
auto process_implicit_input = [&plan, &ort_value_name_to_idx_map,
&outer_scope_arg_to_location_map](const NodeArg& input, size_t /*arg_idx*/) {
const auto& name = input.Name();
OrtValueIndex index = -1;
ORT_RETURN_IF_ERROR(Index(ort_value_name_to_idx_map, name, index));
outer_scope_arg_to_location_map.insert({name, plan.GetLocation(index)});
return Status::OK();
};
ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(parent_node.ImplicitInputDefs(), process_implicit_input));
// Process explicit inputs to the node
// (they are passed through as explicit subgraph inputs and hence requires a re-mapping of names
// to their corresponding names in the inner nested subgraph(s) held by the node)
const auto& subgraph_inputs = subgraph.GetInputs();
auto process_input = [&plan, &ort_value_name_to_idx_map, &outer_scope_arg_to_location_map,
&subgraph_inputs](const NodeArg& input, size_t arg_idx) {
const auto& name = input.Name();
OrtValueIndex index = -1;
ORT_RETURN_IF_ERROR(Index(ort_value_name_to_idx_map, name, index));
// Store the location of the outer scope value in the map using the subgraph input as the key
// as that will be the referenced name in the subgraph (i.e.) re-mapping of names is required
outer_scope_arg_to_location_map.insert({subgraph_inputs[arg_idx]->Name(), plan.GetLocation(index)});
return Status::OK();
};
if (IsNodeWhereNodeInputsAreSameAsExplicitSubgraphInputs(parent_node)) {
return Node::ForEachWithIndex(parent_node.InputDefs(), process_input);
}
return Status::OK();
}
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers,
std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
CreateGraphInfo();
std::unordered_map<std::string, size_t>& constant_initializers_use_count,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
bool graph_info_already_created) {
if (!graph_info_already_created) {
CreateGraphInfo();
}
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
std::vector<const NodeArg*> valid_outer_scope_node_args;
@ -1127,6 +1204,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order, session_options.enable_mem_reuse);
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
execution_providers_, kernel_create_info_map_,
outer_scope_node_arg_to_location_map,
ort_value_name_idx_map_, context, p_seq_exec_plan_));
//Record the allocation plan
@ -1218,8 +1296,19 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
SessionState& subgraph_session_state = *entry->second;
// recurse
// We need to create graph info for the subgraphs because information accumulated there
// is used in OuterScopeNodeArgLocationAccumulator()
subgraph_session_state.CreateGraphInfo();
std::unordered_map<OrtValueName, OrtMemoryInfo> subgraph_outer_scope_node_arg_to_location_map;
ORT_RETURN_IF_ERROR(OuterScopeNodeArgLocationAccumulator(*p_seq_exec_plan_, GetOrtValueNameIdxMap(),
node,
subgraph_session_state.GetGraphViewer(),
subgraph_outer_scope_node_arg_to_location_map));
ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl(
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers, constant_initializers_use_count));
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers,
constant_initializers_use_count, subgraph_outer_scope_node_arg_to_location_map, true));
// setup all the info for handling the feeds and fetches used in subgraph execution
auto* p_op_kernel = GetMutableKernel(node.Index());

View file

@ -101,7 +101,6 @@ class SessionState {
use_deterministic_compute_(use_deterministic_compute),
enable_mem_reuse_(enable_mem_reuse),
prepacked_weights_container_(prepacked_weights_container) {
SetupAllocators();
}
@ -268,7 +267,7 @@ class SessionState {
const KernelCreateInfo& GetNodeKernelCreateInfo(NodeIndex node_index) const;
/// Return SessionState for the given Node index and attribute name if found.
const SessionState* GetSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name) const;
const SessionState* GetSubgraphSessionState(NodeIndex index, const std::string& attribute_name) const;
concurrency::ThreadPool* GetThreadPool() const noexcept { return thread_pool_; }
concurrency::ThreadPool* GetInterOpThreadPool() const noexcept { return inter_op_thread_pool_; }
@ -368,7 +367,9 @@ class SessionState {
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers,
std::unordered_map<std::string, size_t>& constant_initializers_use_count);
std::unordered_map<std::string, size_t>& constant_initializers_use_count,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map = {},
bool graph_info_already_created = false);
#ifdef ENABLE_TRAINING
Status GeneratePatternGroupCache(

View file

@ -321,13 +321,11 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer&
// implicit inputs to a node could come directly from a feed, so we need to make sure they have an entry too
const auto& node_implicit_inputs = node.ImplicitInputDefs();
if (!node_implicit_inputs.empty()) {
// nested subgraph. for now map them to this node (which will be CPU based as all the control flow nodes
// are currently CPU based and they're the only ones that have implicit inputs) as the inputs will be passed as a
// feed when executing the subgraph and need to be in the mapping.
// in the future we want to recurse and find where the implicit input is actually used to try and avoid a
// copy to/from CPU to go through the control flow nodes where possible/applicable.
// the processing for the subgraph where the implicit input is consumed will do the real check on whether any
// copy to a different device is required
// In nested subgraphs, the location of the implicit input(s) is the location it
// is consumed in the subgraph if there is an explicit consumer.
// If the only consumer(s) are implicit consumers (i.e.) other control flow nodes, its
// location is the location of the value in the enclosing outer scope.
// All this is setup in the planner, we just use the location from the plan here.
for (const auto& input_def : node_implicit_inputs) {
int arg_index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index));

View file

@ -12,15 +12,23 @@
#include "core/framework/op_kernel.h"
#include "test/framework/model_builder_utils.h"
#include "core/framework/allocation_planner.h"
#include "core/session/inference_session.h"
#include "core/graph/model.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/util/thread_utils.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
#include "test/util/include/default_providers.h"
using namespace ONNX_NAMESPACE;
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
// GCC 4.x.
// (This static var is referenced in `PassThroughExplicitAndImplicitSubgraphInputs` test)
const OrtDevice::DeviceType OrtDevice::GPU;
namespace onnxruntime {
namespace test {
@ -154,9 +162,9 @@ class PlannerTest : public ::testing::Test {
// some standard components used to build test-cases:
Type float_type_;
std::unique_ptr<::onnxruntime::KernelDef> std_kernel_; // a unary kernel with no-aliasing and no-in-place
std::unique_ptr<::onnxruntime::KernelDef> in_place_kernel_; // a unary kernel with in-place
std::unique_ptr<::onnxruntime::KernelDef> external_outputs_kernel_; // an unary kernel with external outputs
std::unique_ptr<::onnxruntime::KernelDef> std_kernel_; // a unary kernel with no-aliasing and no-in-place
std::unique_ptr<::onnxruntime::KernelDef> in_place_kernel_; // a unary kernel with in-place
std::unique_ptr<::onnxruntime::KernelDef> external_outputs_kernel_; // an unary kernel with external outputs
std::unordered_map<std::string, onnxruntime::NodeArg*> name_to_arg_;
std::vector<std::unique_ptr<UnaryNode>> nodes_;
@ -270,7 +278,7 @@ class PlannerTest : public ::testing::Test {
SequentialPlannerTestContext test_context(&shape_map_);
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_,
kernel_create_info_map, state_->GetOrtValueNameIdxMap(), test_context,
kernel_create_info_map, {}, state_->GetOrtValueNameIdxMap(), test_context,
plan_);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
@ -415,9 +423,9 @@ TEST_F(PlannerTest, ExternalOutputsTest) {
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");
// graph structure:
AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary
AddNormalNode(X2, X3); // normal operator; X3: temporary
AddNormalNode(X3, X4); // normal operator; X4: output
AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary
AddNormalNode(X2, X3); // normal operator; X3: temporary
AddNormalNode(X3, X4); // normal operator; X4: output
// simulate shape-inference results:
Shape shape1{"M", "N"};
@ -505,5 +513,223 @@ TEST_F(PlannerTest, PlanOutputTest) {
}
}
#ifdef USE_CUDA
TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
// Types
TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param");
TypeProto int64_scalar;
int64_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
int64_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
TypeProto bool_scalar;
bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
// The model has a main graph and 2 levels of nested subgraphs
// Main graph: 2 Abs nodes + one Loop node
// First level (Loop) subgraph: Identity (condition pass-through) + If node
// Second level subgraph(s): Then and Else branches: Both have an Add node
// The Add node adds 2 values:
// One value from the main graph ("abs_data_0_out") that is "implicitly"
// consumed by the Loop node and "passed through" to the If subgraphs.
// Another value from the main graph ("abs_data_1_out") that is "explicitly"
// consumed by the Loop node as a loop carried dependency and its name in
// the scope of the Loop node is "loop_state_var".
// In the Loop subgraph, there are no explicit consumers of "abs_data_0_out"
// and "loop_state_var", there is only one implicit consumer - "If".
// We want to ensure that since there are no explicit consumers, the planned locations
// for these values in this subgraph are the same locations as their corresponding
// values in the outer scope, thus deferring any copies (if required) till the actual
// subgraph(s) they are explicitly consumed in.
auto create_model = [&float_tensor, &int64_scalar, &bool_scalar]() -> Model {
auto create_if_subgraph = [&float_tensor](bool is_then) -> GraphProto {
Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
auto& outer_scope_0 = graph.GetOrCreateNodeArg("loop_state_var", &float_tensor);
graph.AddOuterScopeNodeArg("loop_state_var");
auto& outer_scope_1 = graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor);
graph.AddOuterScopeNodeArg("abs_data_0_out");
auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &float_tensor);
graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out});
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());
return graph.ToGraphProto();
};
auto create_loop_subgraph = [&create_if_subgraph, &float_tensor, &int64_scalar, &bool_scalar]() -> GraphProto {
Model model("loop_subgraph", true, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
/* Inputs: iter_num, cond_in, loop carried state variables.
iter_num_in cond_in [loop_state_var]
(unused) | |
[Identity] [If]
| |
cond_out loop_state_var_out
*/
// graph inputs
auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
auto& cond_in = graph.GetOrCreateNodeArg("cond_in", &bool_scalar);
auto& loop_state_var = graph.GetOrCreateNodeArg("loop_state_var", &float_tensor);
// graph outputs
auto& cond_out = graph.GetOrCreateNodeArg("cond_out", &bool_scalar);
auto& loop_state_var_out = graph.GetOrCreateNodeArg("loop_state_var_out", &float_tensor);
// outer scope args
ORT_IGNORE_RETURN_VALUE(graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor));
graph.AddOuterScopeNodeArg("abs_data_0_out");
// cond_in -> cond_out
{
inputs = {&cond_in};
outputs = {&cond_out};
graph.AddNode("cond_in_identity", "Identity", "Forward cond_in to cond_out", inputs, outputs);
}
// loop_state_var -> If(cond_in) -> loop_state_var_out
{
inputs = {&cond_in};
outputs = {&loop_state_var_out};
auto& node = graph.AddNode("loop_var_out", "If", "If with loop_state_var as implicit_input", inputs, outputs);
node.AddAttribute("then_branch", create_if_subgraph(true));
node.AddAttribute("else_branch", create_if_subgraph(false));
}
graph.SetInputs({&iter_num_in, &cond_in, &loop_state_var});
graph.SetOutputs({&cond_out, &loop_state_var_out});
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());
return graph.ToGraphProto();
};
onnxruntime::Model model("main_graph", false, ModelMetaData(),
PathString(), IOnnxRuntimeOpSchemaRegistryList(),
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
auto& main_graph = model.MainGraph();
// Abs-0
auto& abs_data_0_in = main_graph.GetOrCreateNodeArg("abs_data_0_in", &float_tensor);
auto& abs_data_0_out = main_graph.GetOrCreateNodeArg("abs_data_0_out", &float_tensor);
std::vector<onnxruntime::NodeArg*> abs_0_inputs = {&abs_data_0_in};
std::vector<onnxruntime::NodeArg*> abs_0_outputs = {&abs_data_0_out};
main_graph.AddNode("abs_0", "Abs", "node abs", abs_0_inputs, abs_0_outputs);
// Abs-1
auto& abs_data_1_in = main_graph.GetOrCreateNodeArg("abs_data_1_in", &float_tensor);
auto& abs_data_1_out = main_graph.GetOrCreateNodeArg("abs_data_1_out", &float_tensor);
std::vector<onnxruntime::NodeArg*> abs_1_inputs = {&abs_data_1_in};
std::vector<onnxruntime::NodeArg*> abs_1_outputs = {&abs_data_1_out};
main_graph.AddNode("abs_1", "Abs", "node abs", abs_1_inputs, abs_1_outputs);
// Loop
auto& iter_num_in = main_graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
auto& cond_in = main_graph.GetOrCreateNodeArg("cond_in", &bool_scalar);
auto& loop_state_out_var = main_graph.GetOrCreateNodeArg("loop_state_out_var", &float_tensor);
auto& loop_node = main_graph.AddNode("loop", "Loop", "Loop node",
{&iter_num_in, &cond_in, &abs_data_1_out},
{&loop_state_out_var});
loop_node.AddAttribute("body", create_loop_subgraph());
main_graph.SetInputs({&abs_data_0_in, &abs_data_1_in, &iter_num_in, &cond_in});
main_graph.SetOutputs({&loop_state_out_var});
auto status = main_graph.Resolve();
EXPECT_EQ(status, Status::OK());
return model;
};
// Create and load session
SessionOptions so;
InferenceSession sess{so, GetEnvironment()};
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
ASSERT_TRUE(status.IsOK());
std::string s1;
const bool rc = create_model().ToProto().SerializeToString(&s1);
EXPECT_EQ(rc, true);
std::stringstream sstr(s1);
status = sess.Load(sstr);
ASSERT_TRUE(status.IsOK());
status = sess.Initialize();
ASSERT_TRUE(status.IsOK());
// Check planned locations of values in the main graph that are implicit subgraph inputs
// and explicit subgraph inputs to the Loop node
// Main graph (L0 graph)
const auto& main_graph_session_state = sess.GetSessionState();
{
const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap();
const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan();
OrtValueIndex abs_data_0_out_index;
main_graph_ort_value_index_map.GetIdx("abs_data_0_out", abs_data_0_out_index);
OrtValueIndex abs_data_1_out_index;
main_graph_ort_value_index_map.GetIdx("abs_data_1_out", abs_data_1_out_index);
EXPECT_EQ(main_graph_plan->allocation_plan[abs_data_0_out_index].location.device.Type(), OrtDevice::GPU);
EXPECT_EQ(main_graph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
}
// First subgraph (Loop) (L1 graph)
// There are 3 nodes in the main level- Only one of them has a subgraph (Loop).
// Find that.
const SessionState* find_first_subgraph_session_state = nullptr;
for (size_t i = 0; i < 3; ++i) {
find_first_subgraph_session_state = main_graph_session_state.GetSubgraphSessionState(i, "body");
if (find_first_subgraph_session_state) {
break;
}
}
const auto& first_subgraph_session_state = *find_first_subgraph_session_state;
{
const auto& first_subgraph_ort_value_index_map = first_subgraph_session_state.GetOrtValueNameIdxMap();
const auto* first_subgraph_plan = first_subgraph_session_state.GetExecutionPlan();
OrtValueIndex abs_data_0_out_index;
first_subgraph_ort_value_index_map.GetIdx("abs_data_0_out", abs_data_0_out_index);
// "abs_data_1_out" is "loop_state_var" in this scope as it was consumed as an explicit subgraph input
// to Loop's body subgraph
OrtValueIndex abs_data_1_out_index;
first_subgraph_ort_value_index_map.GetIdx("loop_state_var", abs_data_1_out_index);
// There are no explicit consumers of "abs_data_0_out" and "loop_state_var (abs_data_1_out)" in this scope.
// There is only one implicit consumer "If". Hence, check that we are preserving the locations of these values
// from the outer scope, thus deferring any copies till the actual nested subgraph these values are used in.
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_0_out_index].location.device.Type(), OrtDevice::GPU);
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
}
}
#endif
} // namespace test
} // namespace onnxruntime