Add restriction to first usage in allocation planner (#10724)

* Add restriction to first usage in allocation planner

* change phrases

* add UT

Co-authored-by: Ubuntu <wy@linux-v100.aidmrjtolptuzevavgwhrapqcd.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2022-03-02 22:03:50 -08:00 committed by GitHub
parent 47ab0c2006
commit 44d08d80a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 118 additions and 1 deletions

View file

@ -700,6 +700,8 @@ class PlannerImpl {
const std::string& subgraph_kernel_create_info_map_key_base,
size_t graph_depth,
/*out*/ std::vector<std::vector<OrtMemoryInfo>>& locations) {
// First, iterate over the nodes at the current graph level to record the locations
// of usages in the current graph.
for (const auto& node : graph_viewer.Nodes()) {
const auto& input_node_args = node.InputDefs();
size_t num_node_inputs = input_node_args.size();
@ -759,7 +761,10 @@ class PlannerImpl {
locations[wt_index].emplace_back(
GetLocationForNodeInput(node_input_index, node, kernel_create_info_map));
}
}
// Iterate over nodes in current graph with subgraphs and recurse.
for (const auto& node : graph_viewer.Nodes()) {
// If the node has subgraphs (i.e., it is a control flow node),
// walk the nodes in those subgraphs as well to best determine
// the location for the OrtValue corresponding to the weights
@ -790,7 +795,8 @@ class PlannerImpl {
Status GeneratePlanForWeights() {
// TODO: Move away from usage of vector of `OrtMemoryInfo`s per weight (initializer)
// We do not need to maintain a vector of locations that a weight is used in.
// We only need to know the location of its first usage because:
// We only need to know the location of its first usage according to the nodes
// iteration rule in GeneratePlanForWeightsHelper() because:
// (1) If the initializer is used in the graph level it is introduced in, then it can
// only be used on one device as the Memcpy transformer will duplicate the initializer
// (with a different name) in case it is used on multiple devices.

View file

@ -28,6 +28,7 @@ using namespace ONNX_NAMESPACE;
// GCC 4.x.
// (This static var is referenced in some tests below)
const OrtDevice::DeviceType OrtDevice::GPU;
const OrtDevice::DeviceType OrtDevice::CPU;
namespace onnxruntime {
namespace test {
@ -834,6 +835,7 @@ TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInp
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
}
}
TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
// This is a simple model that has one outer scope initializer and an `If` node
// and that initializer is ONLY used in nested subgraphs (both the `If` subgraphs).
@ -933,6 +935,115 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::GPU);
}
TEST_F(PlannerTest, LocationPlanningForInitializersUsedInMainGraphAndSubgraph) {
  // Scenario under test:
  //   main graph:  Abs -> If -> TopK, with a single outer scope initializer `init_data`.
  //   `init_data` is consumed in two places:
  //     * by the `Add` node inside both `If` branches (which requires it on GPU), and
  //     * as the second input of `TopK` in the main graph (which is required on CPU).
  // The planner is expected to place the initializer on CPU: no Memcpy will be inserted
  // for a main graph node that needs the initializer on CPU if the allocation planner
  // has placed that initializer on GPU.

  // INT64 tensor with a single symbolic dimension.
  TypeProto int64_tensor_type;
  {
    auto* tensor_type = int64_tensor_type.mutable_tensor_type();
    tensor_type->set_elem_type(TensorProto_DataType_INT64);
    tensor_type->mutable_shape()->add_dim()->set_dim_param("dim_param");
  }

  // BOOL tensor with one dimension of size 1 (the `If` condition).
  TypeProto bool_condition_type;
  {
    auto* tensor_type = bool_condition_type.mutable_tensor_type();
    tensor_type->set_elem_type(TensorProto_DataType_BOOL);
    tensor_type->mutable_shape()->add_dim()->set_dim_value(1);
  }

  // Builds the body shared by both `If` branches: a single `Add` over two outer scope
  // values. Only the name of the branch output differs between then/else.
  const auto build_branch_subgraph = [&int64_tensor_type](bool is_then) -> GraphProto {
    Model branch_model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger());
    auto& branch_graph = branch_model.MainGraph();

    auto& outer_abs_out = branch_graph.GetOrCreateNodeArg("abs_data_out", &int64_tensor_type);
    branch_graph.AddOuterScopeNodeArg("abs_data_out");

    auto& outer_init_data = branch_graph.GetOrCreateNodeArg("init_data", &int64_tensor_type);
    branch_graph.AddOuterScopeNodeArg("init_data");

    const std::string branch_output_name = is_then ? "if_then_out" : "if_else_out";
    auto& branch_output = branch_graph.GetOrCreateNodeArg(branch_output_name, &int64_tensor_type);

    branch_graph.AddNode("if_out", "Add", "add", {&outer_abs_out, &outer_init_data}, {&branch_output});

    const auto resolve_status = branch_graph.Resolve();
    EXPECT_EQ(resolve_status, Status::OK());

    return branch_graph.ToGraphProto();
  };

  // Builds the complete model (main graph plus both `If` branches).
  const auto create_model = [&]() -> Model {
    onnxruntime::Model model("main_graph", false, ModelMetaData(),
                             PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                             {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
    auto& graph = model.MainGraph();

    // Abs: abs_data_in -> abs_data_out
    auto& abs_input = graph.GetOrCreateNodeArg("abs_data_in", &int64_tensor_type);
    auto& abs_output = graph.GetOrCreateNodeArg("abs_data_out", &int64_tensor_type);
    graph.AddNode("abs_0", "Abs", "node abs", {&abs_input}, {&abs_output});

    // If: if_in -> if_out, both branches share the same body
    auto& if_condition = graph.GetOrCreateNodeArg("if_in", &bool_condition_type);
    auto& if_output = graph.GetOrCreateNodeArg("if_out", &int64_tensor_type);
    auto& if_node = graph.AddNode("if_out", "If", "If", {&if_condition}, {&if_output});
    if_node.AddAttribute("then_branch", build_branch_subgraph(true));
    if_node.AddAttribute("else_branch", build_branch_subgraph(false));

    // TopK: (if_out, init_data) -> (topk_data_out_0, topk_data_out_1)
    auto& topk_values_in = graph.GetOrCreateNodeArg("if_out", &int64_tensor_type);
    auto& topk_k_in = graph.GetOrCreateNodeArg("init_data", &int64_tensor_type);
    auto& topk_out_0 = graph.GetOrCreateNodeArg("topk_data_out_0", &int64_tensor_type);
    auto& topk_out_1 = graph.GetOrCreateNodeArg("topk_data_out_1", &int64_tensor_type);
    graph.AddNode("topk_0", "TopK", "node topk", {&topk_values_in, &topk_k_in},
                  {&topk_out_0, &topk_out_1});

    // The initializer consumed by both the branches' `Add` and the main graph's `TopK`.
    ONNX_NAMESPACE::TensorProto init_data_proto;
    init_data_proto.set_name("init_data");
    init_data_proto.set_data_type(TensorProto_DataType_INT64);
    init_data_proto.add_dims(1);
    init_data_proto.add_int64_data(1);
    graph.AddInitializedTensor(init_data_proto);

    // Graph boundary
    graph.SetInputs({&abs_input, &if_condition});
    graph.SetOutputs({&topk_out_0, &topk_out_1});

    const auto resolve_status = graph.Resolve();
    EXPECT_EQ(resolve_status, Status::OK());

    return model;
  };

  // Set up a session with the CUDA execution provider registered.
  SessionOptions session_options;
  InferenceSession session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider()));

  // Round-trip the model through its serialized form and load it into the session.
  std::string serialized_model;
  const bool serialized_ok = create_model().ToProto().SerializeToString(&serialized_model);
  EXPECT_TRUE(serialized_ok);

  std::stringstream model_stream(serialized_model);
  ASSERT_STATUS_OK(session.Load(model_stream));
  ASSERT_STATUS_OK(session.Initialize());

  // The initializer must have been planned on CPU.
  const auto& session_state = session.GetSessionState();
  const auto& name_to_index = session_state.GetOrtValueNameIdxMap();
  const auto* execution_plan = session_state.GetExecutionPlan();

  OrtValueIndex init_data_index;
  ASSERT_STATUS_OK(name_to_index.GetIdx("init_data", init_data_index));

  EXPECT_EQ(execution_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::CPU);
}
#endif
} // namespace test
} // namespace onnxruntime