mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Kezhan/partition logic update (#164)
* add check before fusing sub-graph in greedy partitioning * update the partitioning logic to 1) not fuse sub-graph if inner nodes were assigned 2) avoid resolving graph after each provider capability checking and assignment. * resolve conflicts
This commit is contained in:
parent
5a8acd7da8
commit
82d04412a0
1 changed files with 46 additions and 20 deletions
|
|
@ -55,6 +55,11 @@ KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxrunti
|
|||
}
|
||||
|
||||
Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
|
||||
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
|
||||
// 1. Execution providers' capabilities are checked one by one.
|
||||
// 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet.
|
||||
// 3. CPU execution provider is expected to be able to run any node and is the last one in execution provider preference.
|
||||
|
||||
if (providers_.Empty()) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified.");
|
||||
}
|
||||
|
|
@ -63,10 +68,17 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
|
|||
std::shared_ptr<KernelRegistry> fused_kernel_registry = std::make_shared<KernelRegistry>();
|
||||
// Partitioning <graph> based on provider preference and their capabilities.
|
||||
auto kernel_registries = kernel_registry_mgr_.GetAllKernelRegistries();
|
||||
|
||||
std::vector<std::vector<std::unique_ptr<ComputeCapability>>> capabilities_of_all_providers;
|
||||
GraphViewer graph_viewer(graph);
|
||||
for (auto& provider : providers_) {
|
||||
capabilities_of_all_providers.push_back(provider->GetCapability(graph_viewer, kernel_registries));
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (auto& provider : providers_) {
|
||||
auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries);
|
||||
int count = 0;
|
||||
for (auto& capability : capability_results) {
|
||||
for (auto& capability : capabilities_of_all_providers[i++]) {
|
||||
if (nullptr == capability || nullptr == capability->sub_graph) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -78,30 +90,44 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
|
|||
|
||||
auto node = graph.GetNode(capability->sub_graph->nodes[0]);
|
||||
if (nullptr != node && node->GetExecutionProviderType().empty()) {
|
||||
// The node was not fused or assigned. Assign it to this <provider>.
|
||||
node->SetExecutionProviderType(provider->Type());
|
||||
}
|
||||
} else {
|
||||
// The <provider> can run a fused <sub_graph> in the <graph>.
|
||||
//
|
||||
// Add fused node into <graph>
|
||||
ORT_ENFORCE(nullptr != capability->sub_graph->GetMetaDef());
|
||||
std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
|
||||
auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
|
||||
fused_node.SetExecutionProviderType(provider->Type());
|
||||
auto fused_kernel_func = capability->fuse_kernel_function;
|
||||
if (fused_kernel_func != nullptr) {
|
||||
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, fused_node);
|
||||
fused_kernel_registry->Register(builder, fused_kernel_func);
|
||||
|
||||
// Check whether any node in the <sub_graph> was already assigned.
|
||||
bool sub_graph_available_for_assignment = true;
|
||||
for (auto node_index : capability->sub_graph->nodes) {
|
||||
auto node = graph.GetNode(node_index);
|
||||
if (nullptr == node || !node->GetExecutionProviderType().empty()) {
|
||||
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
|
||||
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
|
||||
sub_graph_available_for_assignment = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (sub_graph_available_for_assignment) {
|
||||
// Add fused node into <graph>
|
||||
std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
|
||||
auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
|
||||
fused_node.SetExecutionProviderType(provider->Type());
|
||||
auto fused_kernel_func = capability->fuse_kernel_function;
|
||||
if (fused_kernel_func != nullptr) {
|
||||
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, fused_node);
|
||||
fused_kernel_registry->Register(builder, fused_kernel_func);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// all done with this provider, resolve the graph before we move on to the next provider.
|
||||
// This is needed since we create a new GraphViewer() that we pass into the next provider's GetCapability().
|
||||
ORT_ENFORCE(graph.Resolve().IsOK());
|
||||
}
|
||||
|
||||
ORT_ENFORCE(graph.Resolve().IsOK());
|
||||
|
||||
// To see if the node with no provider can be inlined. If one such nodes can be
|
||||
// successfully inlined, we re-run the partitioner on the modified graph.
|
||||
bool inline_flag = false;
|
||||
|
|
@ -126,10 +152,10 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
|
|||
this->Partition(graph);
|
||||
}
|
||||
|
||||
//For some cases, like fp16 on cpu, right now we don't have any kernel support that.
|
||||
//But we will insert cast op to run the model, so skip the error checking here.
|
||||
//If after graph transform phase, the node still not assigned, we will report error
|
||||
//during kernel creation phase.
|
||||
//For some cases, like fp16 on cpu, right now we don't have any kernel support that.
|
||||
//But we will insert cast op to run the model, so skip the error checking here.
|
||||
//If after graph transform phase, the node still not assigned, we will report error
|
||||
//during kernel creation phase.
|
||||
#ifdef COUNT_NON_CUDA_OPS
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.GetExecutionProviderType() != kCudaExecutionProvider &&
|
||||
|
|
|
|||
Loading…
Reference in a new issue