Kezhan/partition logic update (#164)

* add check before fusing sub-graph in greedy partitioning

* update the partitioning logic to 1) not fuse sub-graph if inner nodes were assigned 2) avoid resolving graph after each provider capability checking and assignment.

* resolve conflicts
This commit is contained in:
Ke Zhang 2018-12-17 13:30:29 -08:00 committed by GitHub
parent 5a8acd7da8
commit 82d04412a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -55,6 +55,11 @@ KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxrunti
}
Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
// 1. Execution providers' capabilities are checked one by one.
// 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet.
// 3. CPU execution provider is expected to be able to run any node and is the last one in execution provider preference.
if (providers_.Empty()) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified.");
}
@ -63,10 +68,17 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
std::shared_ptr<KernelRegistry> fused_kernel_registry = std::make_shared<KernelRegistry>();
// Partitioning <graph> based on provider preference and their capabilities.
auto kernel_registries = kernel_registry_mgr_.GetAllKernelRegistries();
std::vector<std::vector<std::unique_ptr<ComputeCapability>>> capabilities_of_all_providers;
GraphViewer graph_viewer(graph);
for (auto& provider : providers_) {
capabilities_of_all_providers.push_back(provider->GetCapability(graph_viewer, kernel_registries));
}
int i = 0;
for (auto& provider : providers_) {
auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries);
int count = 0;
for (auto& capability : capability_results) {
for (auto& capability : capabilities_of_all_providers[i++]) {
if (nullptr == capability || nullptr == capability->sub_graph) {
continue;
}
@ -78,30 +90,44 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
auto node = graph.GetNode(capability->sub_graph->nodes[0]);
if (nullptr != node && node->GetExecutionProviderType().empty()) {
// The node was not fused or assigned. Assign it to this <provider>.
node->SetExecutionProviderType(provider->Type());
}
} else {
// The <provider> can run a fused <sub_graph> in the <graph>.
//
// Add fused node into <graph>
ORT_ENFORCE(nullptr != capability->sub_graph->GetMetaDef());
std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
fused_node.SetExecutionProviderType(provider->Type());
auto fused_kernel_func = capability->fuse_kernel_function;
if (fused_kernel_func != nullptr) {
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
KernelDefBuilder builder;
BuildFusedKernelDef(builder, fused_node);
fused_kernel_registry->Register(builder, fused_kernel_func);
// Check whether any node in the <sub_graph> was already assigned.
bool sub_graph_available_for_assignment = true;
for (auto node_index : capability->sub_graph->nodes) {
auto node = graph.GetNode(node_index);
if (nullptr == node || !node->GetExecutionProviderType().empty()) {
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
sub_graph_available_for_assignment = false;
break;
}
}
if (sub_graph_available_for_assignment) {
// Add fused node into <graph>
std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
fused_node.SetExecutionProviderType(provider->Type());
auto fused_kernel_func = capability->fuse_kernel_function;
if (fused_kernel_func != nullptr) {
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
KernelDefBuilder builder;
BuildFusedKernelDef(builder, fused_node);
fused_kernel_registry->Register(builder, fused_kernel_func);
}
}
}
}
// all done with this provider, resolve the graph before we move on to the next provider.
// This is needed since we create a new GraphViewer() that we pass into the next provider's GetCapability().
ORT_ENFORCE(graph.Resolve().IsOK());
}
ORT_ENFORCE(graph.Resolve().IsOK());
// To see if the node with no provider can be inlined. If one such nodes can be
// successfully inlined, we re-run the partitioner on the modified graph.
bool inline_flag = false;
@ -126,10 +152,10 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
this->Partition(graph);
}
//For some cases, like fp16 on cpu, right now we don't have any kernel support that.
//But we will insert cast op to run the model, so skip the error checking here.
//If after graph transform phase, the node still not assigned, we will report error
//during kernel creation phase.
//For some cases, like fp16 on cpu, right now we don't have any kernel support that.
//But we will insert cast op to run the model, so skip the error checking here.
//If after graph transform phase, the node still not assigned, we will report error
//during kernel creation phase.
#ifdef COUNT_NON_CUDA_OPS
for (auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType() != kCudaExecutionProvider &&