mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add restriction to first usage in allocation planner (#10724)
* Add restriction to first usage in allocation planner * change phrases * add UT Co-authored-by: Ubuntu <wy@linux-v100.aidmrjtolptuzevavgwhrapqcd.jx.internal.cloudapp.net>
This commit is contained in:
parent
47ab0c2006
commit
44d08d80a0
2 changed files with 118 additions and 1 deletions
|
|
@ -700,6 +700,8 @@ class PlannerImpl {
|
|||
const std::string& subgraph_kernel_create_info_map_key_base,
|
||||
size_t graph_depth,
|
||||
/*out*/ std::vector<std::vector<OrtMemoryInfo>>& locations) {
|
||||
// Iterate over nodes in current level firstly to record location of usages
|
||||
// in current graph
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
const auto& input_node_args = node.InputDefs();
|
||||
size_t num_node_inputs = input_node_args.size();
|
||||
|
|
@ -759,7 +761,10 @@ class PlannerImpl {
|
|||
locations[wt_index].emplace_back(
|
||||
GetLocationForNodeInput(node_input_index, node, kernel_create_info_map));
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate over nodes in current graph with subgraphs and recurse.
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
// If the node has subgraphs (i.e.) control flow nodes,
|
||||
// walk the nodes in those subgraphs as well to best determine
|
||||
// the location for the OrtValue corresponding to the weights
|
||||
|
|
@ -790,7 +795,8 @@ class PlannerImpl {
|
|||
Status GeneratePlanForWeights() {
|
||||
// TODO: Move away from usage of vector of `OrtMemoryInfo`s per weight (initializer)
|
||||
// We do not need to maintain a vector of locations that a weight is used in.
|
||||
// We only need to know the location of its first usage because:
|
||||
// We only need to know the location of its first usage according to the nodes
|
||||
// iteration rule in GeneratePlanForWeightsHelper() because:
|
||||
// (1) If the initializer is used in the graph level it is introduced in, then it can
|
||||
// only be used on one device as the Memcpy transformer will duplicate the initializer
|
||||
// (with a different name) in case it is used on multiple devices.
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ using namespace ONNX_NAMESPACE;
|
|||
// GCC 4.x.
|
||||
// (This static var is referenced in some tests below)
|
||||
const OrtDevice::DeviceType OrtDevice::GPU;
|
||||
const OrtDevice::DeviceType OrtDevice::CPU;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -834,6 +835,7 @@ TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInp
|
|||
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
|
||||
// This a simple model that has one outer scope initializer and an `If` node
|
||||
// and that initializer is ONLY used in nested subgraphs (both the `If` subgraphs).
|
||||
|
|
@ -933,6 +935,115 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
|
|||
|
||||
EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::GPU);
|
||||
}
|
||||
|
||||
TEST_F(PlannerTest, LocationPlanningForInitializersUsedInMainGraphAndSubgraph) {
|
||||
// This a simple model that has one outer scope initializer, an `If` node followed
|
||||
// by a `TopK` node. The initializer is used in both nested subgraphs(`Add` consumes it
|
||||
// and requires it on GPU) and main graph(the second input of `TopK` is required on CPU).
|
||||
// The right location for the initializer should be CPU as no Memcpy will be inserted
|
||||
// for a node in main graph that requires the input(initializer) on CPU if that initializer
|
||||
// is placed on GPU by allocation planner.
|
||||
TypeProto int_tensor;
|
||||
int_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
|
||||
int_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param");
|
||||
|
||||
TypeProto bool_scalar;
|
||||
bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
|
||||
bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
auto create_model = [&int_tensor, &bool_scalar]() -> Model {
|
||||
auto create_if_subgraph = [&int_tensor](bool is_then) -> GraphProto {
|
||||
Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
auto& outer_scope_0 = graph.GetOrCreateNodeArg("abs_data_out", &int_tensor);
|
||||
graph.AddOuterScopeNodeArg("abs_data_out");
|
||||
|
||||
auto& outer_scope_1 = graph.GetOrCreateNodeArg("init_data", &int_tensor);
|
||||
graph.AddOuterScopeNodeArg("init_data");
|
||||
|
||||
auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &int_tensor);
|
||||
graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return graph.ToGraphProto();
|
||||
};
|
||||
|
||||
onnxruntime::Model model("main_graph", false, ModelMetaData(),
|
||||
PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& main_graph = model.MainGraph();
|
||||
|
||||
// Abs-0
|
||||
auto& abs_data_in = main_graph.GetOrCreateNodeArg("abs_data_in", &int_tensor);
|
||||
auto& abs_data_out = main_graph.GetOrCreateNodeArg("abs_data_out", &int_tensor);
|
||||
main_graph.AddNode("abs_0", "Abs", "node abs", {&abs_data_in}, {&abs_data_out});
|
||||
|
||||
// If
|
||||
auto& if_in = main_graph.GetOrCreateNodeArg("if_in", &bool_scalar);
|
||||
auto& if_out = main_graph.GetOrCreateNodeArg("if_out", &int_tensor);
|
||||
auto& node = main_graph.AddNode("if_out", "If", "If", {&if_in}, {&if_out});
|
||||
node.AddAttribute("then_branch", create_if_subgraph(true));
|
||||
node.AddAttribute("else_branch", create_if_subgraph(false));
|
||||
|
||||
// TopK
|
||||
auto& topk_data_in_0 = main_graph.GetOrCreateNodeArg("if_out", &int_tensor);
|
||||
auto& topk_data_in_1 = main_graph.GetOrCreateNodeArg("init_data", &int_tensor);
|
||||
auto& topk_data_out_0 = main_graph.GetOrCreateNodeArg("topk_data_out_0", &int_tensor);
|
||||
auto& topk_data_out_1 = main_graph.GetOrCreateNodeArg("topk_data_out_1", &int_tensor);
|
||||
main_graph.AddNode("topk_0", "TopK", "node topk", {&topk_data_in_0, &topk_data_in_1},
|
||||
{&topk_data_out_0, &topk_data_out_1});
|
||||
|
||||
// Add initializer to the graph
|
||||
ONNX_NAMESPACE::TensorProto tensor;
|
||||
tensor.add_dims(1);
|
||||
tensor.add_int64_data(1);
|
||||
tensor.set_data_type(TensorProto_DataType_INT64);
|
||||
tensor.set_name("init_data");
|
||||
main_graph.AddInitializedTensor(tensor);
|
||||
|
||||
// Main graph's inputs/outputs
|
||||
main_graph.SetInputs({&abs_data_in, &if_in});
|
||||
main_graph.SetOutputs({&topk_data_out_0, &topk_data_out_1});
|
||||
|
||||
auto status = main_graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return model;
|
||||
};
|
||||
|
||||
// Create and load session
|
||||
SessionOptions so;
|
||||
InferenceSession sess{so, GetEnvironment()};
|
||||
|
||||
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
std::string s1;
|
||||
const bool rc = create_model().ToProto().SerializeToString(&s1);
|
||||
EXPECT_EQ(rc, true);
|
||||
std::stringstream sstr(s1);
|
||||
|
||||
status = sess.Load(sstr);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
status = sess.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
// Check planned locations for the initializer
|
||||
const auto& main_graph_session_state = sess.GetSessionState();
|
||||
const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap();
|
||||
const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan();
|
||||
|
||||
OrtValueIndex init_data_index;
|
||||
ASSERT_STATUS_OK(main_graph_ort_value_index_map.GetIdx("init_data", init_data_index));
|
||||
|
||||
EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::CPU);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue