From 80f686e055e90bd7c57978cee0583da71cd4382b Mon Sep 17 00:00:00 2001
From: Peishen Yan
Date: Fri, 17 Jan 2025 05:26:22 +0800
Subject: [PATCH] [WebNN EP] Optimize model partitioning (#23332)

### Description
The old `GetCapability` function of the WebNN EP performed only a very simple search for groups of nodes that can be handled. This does not work well for the following example graph, where A and D could be handled by the EP but B sits between them in the topological order, so you get two single-node capabilities. However, the old approach could also be advantageous if C and E could be handled by the EP, since they would be combined with D even though they are not connected.

```
  A  B  C
  | /   |
  D     E
  |     |
```

Therefore, we improve the partitioning results by reusing `utils::CreateSupportedPartitions`, which walks the edges of each node that the EP can handle as the nodes are iterated in topological order. This guarantees that all connected nodes that can be handled are grouped together.

Correspondingly, we modify the `webnn::GetSupportedNodes` function to return the set of supported nodes instead of a list of supported node groups.
### Motivation and Context Co-authored-by: Dwayne Robinson --- .../core/providers/webnn/builders/helper.cc | 42 ++---- .../core/providers/webnn/builders/helper.h | 12 +- .../webnn/webnn_execution_provider.cc | 122 ++++++------------ 3 files changed, 60 insertions(+), 116 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 45a8796012..e5124a90df 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,44 +99,30 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder, - const WebnnDeviceType device_type, - const emscripten::val& wnn_limits, - const logging::Logger& logger) { - std::vector> supported_node_groups; - std::vector supported_node_group; - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); +std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder, + const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, + const logging::Logger& logger) { + std::unordered_set supported_nodes; - for (size_t i = 0; i < node_indices.size(); i++) { - auto node_idx = node_indices[i]; - const auto* node(graph_viewer.GetNode(node_idx)); + for (const auto& node : graph_viewer.Nodes()) { bool supported = false; // Firstly check if platform supports the WebNN op. 
- if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { - supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); + if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) { + supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger); } - - LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() - << "] index: [" << node_idx - << "] name: [" << node->Name() + LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() + << "] index: [" << node.Index() + << "] name: [" << node.Name() << "] supported: [" << supported << "]"; if (supported) { - supported_node_group.push_back(node_idx); - } else { - if (!supported_node_group.empty()) { - supported_node_groups.push_back(supported_node_group); - supported_node_group.clear(); - } + supported_nodes.insert(&node); } } - if (!supported_node_group.empty()) { - supported_node_groups.push_back(supported_node_group); - } - - return supported_node_groups; + return supported_nodes; } bool AreInputDataTypesSame(const std::string& op_type, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index c4e73809e5..27607ddb4d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -188,12 +188,12 @@ inline bool TensorExists(const ConstPointerContainer>& def bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger, bool allow_empty_input = false); -// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. -std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder, - const WebnnDeviceType device_type, - const emscripten::val& wnn_limits, - const logging::Logger& logger); +// Get a set of nodes supported by WebNN EP. 
+std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder, + const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, + const logging::Logger& logger); // TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops, // we need to check the support of the decomposed ops. static const InlinedHashMap op_map = { diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 1a337e185b..00fbb26b73 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -13,6 +13,9 @@ #include "core/common/safeint.h" #include "core/providers/webnn/allocator.h" #include "core/providers/webnn/data_transfer.h" +#include "core/providers/partitioning_utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "builders/model.h" #include "builders/helper.h" @@ -20,6 +23,8 @@ namespace onnxruntime { +constexpr const char* WEBNN = "WEBNN"; + WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{ onnxruntime::kWebNNExecutionProvider, @@ -51,8 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {} std::vector> WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/) const { - std::vector> result; - // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for // identifying the required initializer names and storing into 'meta_def->constant_initializers'. 
@@ -64,23 +67,6 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view all_initializers = webnn::CollectAllInitializedTensors(graph_viewer); } - /* - Very basic search for groups of nodes that can be handled by the EP. - This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP - but B is between them in the topological sort as you'll get two single node capabilities. However if can also - be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected. - Not sure how often each of these scenarios happens. - - A B C - | / | - D E - | | - - Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order, - accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all - connected nodes that can be handled are grouped together. - */ - const auto& logger = *GetLogger(); emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); @@ -88,43 +74,37 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); - wnn_builder = emscripten::val::undefined(); + // Get all the NodeUnits in the graph_viewer + std::vector> node_unit_holder; + std::unordered_map node_unit_map; - if (node_groups.empty()) { - return result; - } + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); + + const auto supported_nodes = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); + + const auto gen_metadef_name = [&]() { + HashValue model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(WEBNN, "_", model_hash, "_", metadef_id); + 
}; + + auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, + gen_metadef_name, WEBNN, kWebNNExecutionProvider, + &node_unit_map, /*drop_constant_initializers*/ true); + + // Release wnn_builder + wnn_builder = emscripten::val::undefined(); const auto& graph_output_list = graph_viewer.GetOutputs(); InlinedHashSet graph_outputs(graph_output_list.cbegin(), graph_output_list.cend()); - size_t num_of_supported_nodes = 0; - for (const auto& group : node_groups) { - if (group.empty()) + for (auto& capability : result) { + auto& sub_graph = capability->sub_graph; + if (sub_graph->nodes.empty()) continue; - num_of_supported_nodes += group.size(); - LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: " - << group.size(); - - InlinedHashSet node_set; - node_set.reserve(group.size()); - for (const auto& index : group) { - node_set.insert(index); - } - - std::unique_ptr sub_graph = std::make_unique(); - std::vector subgraph_initializers; - InlinedHashSet node_outputs; - InlinedHashSet subgraph_inputs; - InlinedHashSet subgraph_outputs; - std::vector ordered_subgraph_inputs; - // Output should be unique. It may be produced as graph output and subgraph output. - InlinedHashSet ordered_subgraph_outputs; - - for (const auto& index : group) { - sub_graph->nodes.push_back(index); + for (const auto& index : sub_graph->nodes) { const auto* node = graph_viewer.GetNode(index); for (const auto* input : node->InputDefs()) { @@ -136,39 +116,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view if (is_subgraph && Contains(all_initializers, input->Name())) { subgraph_initializers.push_back(input->Name()); } - // If the node input was not produced by this subgraph, add it to the subgraph inputs. 
- if (node_outputs.count(input) == 0) { - if (subgraph_inputs.count(input) == 0) { - subgraph_inputs.insert(input); - ordered_subgraph_inputs.push_back(input); - } - } - } - - const auto& output_defs = node->OutputDefs(); - for (const auto* output_def : output_defs) { - node_outputs.insert(output_def); - // if output is overall graph output we need to produce it. - if (graph_outputs.count(output_def) != 0) { - ordered_subgraph_outputs.insert(output_def); - } - } - - // if output connects to a node not in this subgraph we need to produce it. - for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { - if (node_set.count(it->GetNode().Index()) == 0) { - const auto* output_def = output_defs[it->GetSrcArgIndex()]; - if (subgraph_outputs.count(output_def) == 0) { - subgraph_outputs.insert(output_def); - ordered_subgraph_outputs.insert(output_def); - } - } } } // Assign inputs and outputs to subgraph's meta_def. uint64_t model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + const auto meta_def_old = sub_graph->GetMetaDef(); auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); meta_def->domain = kMSDomain; @@ -181,20 +135,24 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view } } - for (const auto& input : ordered_subgraph_inputs) { - meta_def->inputs.push_back(input->Name()); + for (const auto& input : meta_def_old->inputs) { + meta_def->inputs.push_back(input); } - for (const auto& output : ordered_subgraph_outputs) { - meta_def->outputs.push_back(output->Name()); + for (const auto& output : meta_def_old->outputs) { + meta_def->outputs.push_back(output); } sub_graph->SetMetaDef(std::move(meta_def)); - - result.push_back(std::make_unique(std::move(sub_graph))); } - auto num_of_partitions = result.size(); + const auto num_of_partitions = result.size(); + 
const auto num_of_supported_nodes = std::accumulate( + result.begin(), result.end(), size_t{0}, + [](const auto& acc, const auto& partition) -> size_t { + return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0); + }); + const auto summary_msg = MakeString( "WebNNExecutionProvider::GetCapability,", " number of partitions supported by WebNN: ", num_of_partitions,