mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Handle implicit subgraph inputs required on different devices in Memcpy transformer (#9299)
This commit is contained in:
parent
48737091c0
commit
d5c5c4fa50
2 changed files with 200 additions and 7 deletions
|
|
@ -153,6 +153,35 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
|
|||
modified = true;
|
||||
}
|
||||
|
||||
// Process implicit inputs in subgraphs that is explicitly consumed
|
||||
// on both provider and non-provider nodes. This is mimicking
|
||||
// logic for explicit graph inputs.
|
||||
if (graph_.IsSubgraph()) {
|
||||
for (auto arg : graph_.ParentNode()->ImplicitInputDefs()) {
|
||||
// Looking into `provider_input_defs_` and `non_provider_input_defs_`
|
||||
// using NodeArg pointers from the outer scope is okay because the
|
||||
// comparator is only name based (and doesn't compare raw pointers)
|
||||
if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) {
|
||||
// There should be at-least one explicit consumer of the NodeArg
|
||||
// in both the provider node list and the non-provider node list.
|
||||
// If there are no explicit consumers in both lists, we don't want
|
||||
// to get into the business of adding copy nodes at this
|
||||
// level.
|
||||
// If there are explicit consumers in only one list (either provider
|
||||
// or non-provider node consumers), there isn't any point in adding
|
||||
// copy nodes in that case either as subgraph copy logic will take
|
||||
// it to the required device (i.e.) we don't need to care about it here.
|
||||
|
||||
// Be sure to use the NodeArg* relevant to the current graph level
|
||||
// (the name will be the same as the parent node's implicit input)
|
||||
const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg);
|
||||
|
||||
AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
}
|
||||
|
||||
|
|
@ -176,7 +205,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
}
|
||||
|
||||
// implicit inputs have no location info in the kernel def, so do nothing to them here, leaving the control
|
||||
// flow op (Loop, Scan, If) to do the necessary copy if the input crosses different provider.
|
||||
// flow op (Loop, Scan, If) to do the necessary copy if the input crosses different provider.
|
||||
// PlannerImpl::ComputeUseCounts has matching logic so the allocation plan does the same thing
|
||||
if (!is_implicit_input) {
|
||||
if (utils::IsInputOnCpu(node, kci, index)) {
|
||||
|
|
@ -220,10 +249,14 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
non_provider_input_defs_.insert(arg);
|
||||
}
|
||||
|
||||
for (const auto* arg : node.ImplicitInputDefs()) {
|
||||
if (arg->Exists())
|
||||
non_provider_input_defs_.insert(arg);
|
||||
}
|
||||
// Never add an implicit def to provider_input_defs_ or non_provider_input_defs_.
|
||||
// This is because we don't want to add copy nodes on account of implicit
|
||||
// inputs to nodes.
|
||||
// We will rely on utils::CopyInputsAcrossDevices() to do the job.
|
||||
//for (const auto* arg : node.ImplicitInputDefs()) {
|
||||
// if (arg->Exists())
|
||||
// non_provider_input_defs_.insert(arg);
|
||||
//}
|
||||
|
||||
for (auto* arg : node.MutableOutputDefs()) {
|
||||
if (arg->Exists())
|
||||
|
|
|
|||
|
|
@ -182,10 +182,16 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) {
|
|||
ExpectSame(node2, node4, 0);
|
||||
ExpectSame(node2, node4, 1);
|
||||
}
|
||||
TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
|
||||
TEST(TransformerTest, TestInitializerDuplicationInSubgraph) {
|
||||
// In this test, we are going to create a subgraph consuming an implicit input
|
||||
// which is an initializer in the outer scope, and this implicit input to the subgraph
|
||||
// is consumed by nodes on multiple devices
|
||||
// is consumed by nodes on multiple devices.
|
||||
|
||||
// Since, the outer scope initializer is consumed on different devices in the subgraph,
|
||||
// a copy of the initializer is made in the subgraph to be provided to the provider (CUDA) node.
|
||||
// No explicit copy nodes are inserted in this scenario and hence we do not check for copy nodes.
|
||||
// Instead, we do check if the transformer modified the graph while processing the parent initializer
|
||||
// in the subgraph.
|
||||
TensorProto value_tensor;
|
||||
value_tensor.add_dims(1);
|
||||
value_tensor.add_float_data(1.f);
|
||||
|
|
@ -288,6 +294,160 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) {
|
|||
EXPECT_TRUE(modified);
|
||||
}
|
||||
|
||||
TEST(TransformerTest, MemcpyTransformerTestGraphInputConsumedOnMultipleDevices) {
|
||||
// In this test, a graph input is consumed by 2 nodes partitioned to different devices.
|
||||
// We expect a copy node to get inserted to the provider (CUDA) node while consuming
|
||||
// the graph input.
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = 7;
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
TypeProto tensor_float_type;
|
||||
tensor_float_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
onnxruntime::NodeArg i1_def("I1", &tensor_float_type),
|
||||
o1_def("O1", &tensor_float_type),
|
||||
o2_def("O2", &tensor_float_type);
|
||||
|
||||
// I1 is a graph input that is consumed by 2 MatMul nodes on different devices
|
||||
auto& node1 = graph.AddNode("node1", "MatMul", "cpu operator1", ArgMap{&i1_def, &i1_def}, ArgMap{&o1_def});
|
||||
node1.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
|
||||
auto& node2 = graph.AddNode("node2", "MatMul", "gpu operator1", ArgMap{&i1_def, &i1_def}, ArgMap{&o2_def});
|
||||
node2.SetExecutionProviderType(onnxruntime::kCudaExecutionProvider);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
ExecutionProviders execution_providers;
|
||||
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider()));
|
||||
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider,
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo())));
|
||||
KernelRegistryManager test_registry_manager;
|
||||
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
|
||||
|
||||
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_TRUE(modified);
|
||||
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_count_map["MemcpyFromHost"] == 1);
|
||||
}
|
||||
|
||||
TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevices) {
|
||||
// In this test, an implicit input (consumed by If subgraphs)
|
||||
// is consumed by 2 nodes partitioned to different devices.
|
||||
// We expect a copy node to get inserted to the provider (CUDA) node while consuming
|
||||
// the implicit input.
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = 7;
|
||||
auto model = std::make_shared<onnxruntime::Model>("test", false, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
|
||||
std::unordered_map<std::string, int> subgraph_domain_to_version;
|
||||
subgraph_domain_to_version[kOnnxDomain] = 7;
|
||||
auto sub_model = std::make_shared<onnxruntime::Model>("test_subgraph",
|
||||
false,
|
||||
ModelMetaData(),
|
||||
PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(),
|
||||
subgraph_domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
|
||||
DefaultLoggingManager().DefaultLogger());
|
||||
onnxruntime::Graph& subgraph = sub_model->MainGraph();
|
||||
|
||||
TypeProto tensor_float_type;
|
||||
tensor_float_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
||||
TypeProto tensor_bool_type;
|
||||
tensor_bool_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
|
||||
|
||||
onnxruntime::NodeArg i1_def("I1", &tensor_bool_type),
|
||||
i2_def("I2", &tensor_float_type),
|
||||
o1_def("O1", &tensor_float_type),
|
||||
o2_def("O2", &tensor_float_type);
|
||||
|
||||
// I1 is a subgraph input that is consumed by 2 MatMul nodes on different devices
|
||||
auto& implicit_input_arg = graph.GetOrCreateNodeArg("I2", &tensor_float_type);
|
||||
subgraph.AddNode("node1", "MatMul", "cpu operator1", ArgMap{&implicit_input_arg, &implicit_input_arg}, ArgMap{&o1_def});
|
||||
subgraph.AddNode("node2", "MatMul", "gpu operator1", ArgMap{&implicit_input_arg, &implicit_input_arg}, ArgMap{&o2_def});
|
||||
|
||||
subgraph.AddOuterScopeNodeArg("I2");
|
||||
|
||||
auto status = subgraph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
// Main graph continued
|
||||
TensorProto init;
|
||||
init.add_dims(1);
|
||||
init.add_int32_data(1);
|
||||
init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
|
||||
init.set_name("I1");
|
||||
graph.AddInitializedTensor(init);
|
||||
|
||||
auto& if_node = graph.AddNode("node3", "If", "gpu operator", ArgMap{&i1_def}, ArgMap{&o1_def, &o2_def});
|
||||
|
||||
if_node.AddAttribute("then_branch", subgraph.ToGraphProto());
|
||||
if_node.AddAttribute("else_branch", subgraph.ToGraphProto());
|
||||
|
||||
graph.SetInputs({&i1_def, &i2_def});
|
||||
|
||||
onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch");
|
||||
for (auto& node : subgraph_1->Nodes()) {
|
||||
if (node.Name() == "node2") {
|
||||
// only this node is on GPU
|
||||
node.SetExecutionProviderType(onnxruntime::kCudaExecutionProvider);
|
||||
} else {
|
||||
node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::Graph* subgraph_2 = if_node.GetMutableGraphAttribute("else_branch");
|
||||
for (auto& node : subgraph_2->Nodes()) {
|
||||
if (node.Name() == "node2") {
|
||||
// only this node is on GPU
|
||||
node.SetExecutionProviderType(onnxruntime::kCudaExecutionProvider);
|
||||
} else {
|
||||
node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
|
||||
}
|
||||
}
|
||||
|
||||
status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
KernelRegistryManager kernel_registry_manager;
|
||||
ExecutionProviders execution_providers;
|
||||
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider()));
|
||||
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider,
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo())));
|
||||
KernelRegistryManager test_registry_manager;
|
||||
ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers));
|
||||
|
||||
MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager);
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_TRUE(modified);
|
||||
|
||||
// We expect to see copy nodes inserted in each of the subgraphs
|
||||
// because an implicit input is consumed both by provider (CUDA) and
|
||||
// non-provider (CPU) nodes.
|
||||
auto op_count_map = CountOpsInGraph(*subgraph_1);
|
||||
ASSERT_TRUE(op_count_map["MemcpyFromHost"] == 1);
|
||||
|
||||
op_count_map = CountOpsInGraph(*subgraph_2);
|
||||
ASSERT_TRUE(op_count_map["MemcpyFromHost"] == 1);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue