From 44d08d80a03654b4e8e06d8a49056258a2050be8 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 2 Mar 2022 22:03:50 -0800 Subject: [PATCH] Add restriction to first usage in allocation planner (#10724) * Add restriction to first usage in allocation planner * change phrases * add UT Co-authored-by: Ubuntu --- .../core/framework/allocation_planner.cc | 8 +- .../test/framework/allocation_planner_test.cc | 111 ++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index c0958a6e2c..5af5ac15c0 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -700,6 +700,8 @@ class PlannerImpl { const std::string& subgraph_kernel_create_info_map_key_base, size_t graph_depth, /*out*/ std::vector>& locations) { + // Iterate over nodes in current level firstly to record location of usages + // in current graph for (const auto& node : graph_viewer.Nodes()) { const auto& input_node_args = node.InputDefs(); size_t num_node_inputs = input_node_args.size(); @@ -759,7 +761,10 @@ class PlannerImpl { locations[wt_index].emplace_back( GetLocationForNodeInput(node_input_index, node, kernel_create_info_map)); } + } + // Iterate over nodes in current graph with subgraphs and recurse. + for (const auto& node : graph_viewer.Nodes()) { // If the node has subgraphs (i.e.) control flow nodes, // walk the nodes in those subgraphs as well to best determine // the location for the OrtValue corresponding to the weights @@ -790,7 +795,8 @@ class PlannerImpl { Status GeneratePlanForWeights() { // TODO: Move away from usage of vector of `OrtMemoryInfo`s per weight (initializer) // We do not need to maintain a vector of locations that a weight is used in. - // We only need to know the location of its first usage because: + // We only need to know the location of its first usage according to the nodes + // iteration rule in GeneratePlanForWeightsHelper() because: // (1) If the initializer is used in the graph level it is introduced in, then it can // only be used on one device as the Memcpy transformer will duplicate the initializer // (with a different name) in case it is used on multiple devices. diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index e38a092cd4..cf165bc046 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -28,6 +28,7 @@ using namespace ONNX_NAMESPACE; // GCC 4.x. // (This static var is referenced in some tests below) const OrtDevice::DeviceType OrtDevice::GPU; +const OrtDevice::DeviceType OrtDevice::CPU; namespace onnxruntime { namespace test { @@ -834,6 +835,7 @@ TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInp EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU); } } + TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) { // This a simple model that has one outer scope initializer and an `If` node // and that initializer is ONLY used in nested subgraphs (both the `If` subgraphs). @@ -933,6 +935,115 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) { EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::GPU); } + +TEST_F(PlannerTest, LocationPlanningForInitializersUsedInMainGraphAndSubgraph) { + // This a simple model that has one outer scope initializer, an `If` node followed + // by a `TopK` node. The initializer is used in both nested subgraphs(`Add` consumes it + // and requires it on GPU) and main graph(the second input of `TopK` is required on CPU). + // The right location for the initializer should be CPU as no Memcpy will be inserted + // for a node in main graph that requires the input(initializer) on CPU if that initializer + // is placed on GPU by allocation planner. + TypeProto int_tensor; + int_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + int_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param"); + + TypeProto bool_scalar; + bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL); + bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto create_model = [&int_tensor, &bool_scalar]() -> Model { + auto create_if_subgraph = [&int_tensor](bool is_then) -> GraphProto { + Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + auto& outer_scope_0 = graph.GetOrCreateNodeArg("abs_data_out", &int_tensor); + graph.AddOuterScopeNodeArg("abs_data_out"); + + auto& outer_scope_1 = graph.GetOrCreateNodeArg("init_data", &int_tensor); + graph.AddOuterScopeNodeArg("init_data"); + + auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &int_tensor); + graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out}); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + return graph.ToGraphProto(); + }; + + onnxruntime::Model model("main_graph", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), + {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + auto& main_graph = model.MainGraph(); + + // Abs-0 + auto& abs_data_in = main_graph.GetOrCreateNodeArg("abs_data_in", &int_tensor); + auto& abs_data_out = main_graph.GetOrCreateNodeArg("abs_data_out", &int_tensor); + main_graph.AddNode("abs_0", "Abs", "node abs", {&abs_data_in}, {&abs_data_out}); + + // If + auto& if_in = main_graph.GetOrCreateNodeArg("if_in", &bool_scalar); + auto& if_out = main_graph.GetOrCreateNodeArg("if_out", &int_tensor); + auto& node = main_graph.AddNode("if_out", "If", "If", {&if_in}, {&if_out}); + node.AddAttribute("then_branch", create_if_subgraph(true)); + node.AddAttribute("else_branch", create_if_subgraph(false)); + + // TopK + auto& topk_data_in_0 = main_graph.GetOrCreateNodeArg("if_out", &int_tensor); + auto& topk_data_in_1 = main_graph.GetOrCreateNodeArg("init_data", &int_tensor); + auto& topk_data_out_0 = main_graph.GetOrCreateNodeArg("topk_data_out_0", &int_tensor); + auto& topk_data_out_1 = main_graph.GetOrCreateNodeArg("topk_data_out_1", &int_tensor); + main_graph.AddNode("topk_0", "TopK", "node topk", {&topk_data_in_0, &topk_data_in_1}, + {&topk_data_out_0, &topk_data_out_1}); + + // Add initializer to the graph + ONNX_NAMESPACE::TensorProto tensor; + tensor.add_dims(1); + tensor.add_int64_data(1); + tensor.set_data_type(TensorProto_DataType_INT64); + tensor.set_name("init_data"); + main_graph.AddInitializedTensor(tensor); + + // Main graph's inputs/outputs + main_graph.SetInputs({&abs_data_in, &if_in}); + main_graph.SetOutputs({&topk_data_out_0, &topk_data_out_1}); + + auto status = main_graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + return model; + }; + + // Create and load session + SessionOptions so; + InferenceSession sess{so, GetEnvironment()}; + + auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); + ASSERT_TRUE(status.IsOK()); + + std::string s1; + const bool rc = create_model().ToProto().SerializeToString(&s1); + EXPECT_EQ(rc, true); + std::stringstream sstr(s1); + + status = sess.Load(sstr); + ASSERT_TRUE(status.IsOK()); + + status = sess.Initialize(); + ASSERT_TRUE(status.IsOK()); + + // Check planned locations for the initializer + const auto& main_graph_session_state = sess.GetSessionState(); + const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap(); + const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan(); + + OrtValueIndex init_data_index; + ASSERT_STATUS_OK(main_graph_ort_value_index_map.GetIdx("init_data", init_data_index)); + + EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::CPU); +} + #endif + } // namespace test } // namespace onnxruntime