diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 45a8796012..e5124a90df 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,44 +99,30 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder, - const WebnnDeviceType device_type, - const emscripten::val& wnn_limits, - const logging::Logger& logger) { - std::vector> supported_node_groups; - std::vector supported_node_group; - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); +std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder, + const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, + const logging::Logger& logger) { + std::unordered_set supported_nodes; - for (size_t i = 0; i < node_indices.size(); i++) { - auto node_idx = node_indices[i]; - const auto* node(graph_viewer.GetNode(node_idx)); + for (const auto& node : graph_viewer.Nodes()) { bool supported = false; // Firstly check if platform supports the WebNN op. - if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { - supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); + if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) { + supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger); } - - LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() - << "] index: [" << node_idx - << "] name: [" << node->Name() + LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() + << "] index: [" << node.Index() + << "] name: [" << node.Name() << "] supported: [" << supported << "]"; if (supported) { - supported_node_group.push_back(node_idx); - } else { - if (!supported_node_group.empty()) { - supported_node_groups.push_back(supported_node_group); - supported_node_group.clear(); - } + supported_nodes.insert(&node); } } - if (!supported_node_group.empty()) { - supported_node_groups.push_back(supported_node_group); - } - - return supported_node_groups; + return supported_nodes; } bool AreInputDataTypesSame(const std::string& op_type, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index c4e73809e5..27607ddb4d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -188,12 +188,12 @@ inline bool TensorExists(const ConstPointerContainer>& def bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger, bool allow_empty_input = false); -// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. -std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder, - const WebnnDeviceType device_type, - const emscripten::val& wnn_limits, - const logging::Logger& logger); +// Get a set of nodes supported by WebNN EP. +std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, + const emscripten::val& wnn_builder, + const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, + const logging::Logger& logger); // TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops, // we need to check the support of the decomposed ops. static const InlinedHashMap op_map = { diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 1a337e185b..00fbb26b73 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -13,6 +13,9 @@ #include "core/common/safeint.h" #include "core/providers/webnn/allocator.h" #include "core/providers/webnn/data_transfer.h" +#include "core/providers/partitioning_utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "builders/model.h" #include "builders/helper.h" @@ -20,6 +23,8 @@ namespace onnxruntime { +constexpr const char* WEBNN = "WEBNN"; + WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{ onnxruntime::kWebNNExecutionProvider, @@ -51,8 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {} std::vector> WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/) const { - std::vector> result; - // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for // identifying the required initializer names and storing into 'meta_def->constant_initializers'. @@ -64,23 +67,6 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view all_initializers = webnn::CollectAllInitializedTensors(graph_viewer); } - /* - Very basic search for groups of nodes that can be handled by the EP. - This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP - but B is between them in the topological sort as you'll get two single node capabilities. However if can also - be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected. - Not sure how often each of these scenarios happens. - - A B C - | / | - D E - | | - - Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order, - accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all - connected nodes that can be handled are grouped together. - */ - const auto& logger = *GetLogger(); emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); @@ -88,43 +74,37 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); - wnn_builder = emscripten::val::undefined(); + // Get all the NodeUnits in the graph_viewer + std::vector> node_unit_holder; + std::unordered_map node_unit_map; - if (node_groups.empty()) { - return result; - } + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); + + const auto supported_nodes = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); + + const auto gen_metadef_name = [&]() { + HashValue model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(WEBNN, "_", model_hash, "_", metadef_id); + }; + + auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, + gen_metadef_name, WEBNN, kWebNNExecutionProvider, + &node_unit_map, /*drop_constant_initializers*/ true); + + // Release wnn_builder + wnn_builder = emscripten::val::undefined(); const auto& graph_output_list = graph_viewer.GetOutputs(); InlinedHashSet graph_outputs(graph_output_list.cbegin(), graph_output_list.cend()); - size_t num_of_supported_nodes = 0; - for (const auto& group : node_groups) { - if (group.empty()) + for (auto& capability : result) { + auto& sub_graph = capability->sub_graph; + if (sub_graph->nodes.empty()) continue; - num_of_supported_nodes += group.size(); - LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: " - << group.size(); - - InlinedHashSet node_set; - node_set.reserve(group.size()); - for (const auto& index : group) { - node_set.insert(index); - } - - std::unique_ptr sub_graph = std::make_unique(); - std::vector subgraph_initializers; - InlinedHashSet node_outputs; - InlinedHashSet subgraph_inputs; - InlinedHashSet subgraph_outputs; - std::vector ordered_subgraph_inputs; - // Output should be unique. It may be produced as graph output and subgraph output. - InlinedHashSet ordered_subgraph_outputs; - - for (const auto& index : group) { - sub_graph->nodes.push_back(index); + for (const auto& index : sub_graph->nodes) { const auto* node = graph_viewer.GetNode(index); for (const auto* input : node->InputDefs()) { @@ -136,39 +116,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view if (is_subgraph && Contains(all_initializers, input->Name())) { subgraph_initializers.push_back(input->Name()); } - // If the node input was not produced by this subgraph, add it to the subgraph inputs. - if (node_outputs.count(input) == 0) { - if (subgraph_inputs.count(input) == 0) { - subgraph_inputs.insert(input); - ordered_subgraph_inputs.push_back(input); - } - } - } - - const auto& output_defs = node->OutputDefs(); - for (const auto* output_def : output_defs) { - node_outputs.insert(output_def); - // if output is overall graph output we need to produce it. - if (graph_outputs.count(output_def) != 0) { - ordered_subgraph_outputs.insert(output_def); - } - } - - // if output connects to a node not in this subgraph we need to produce it. - for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { - if (node_set.count(it->GetNode().Index()) == 0) { - const auto* output_def = output_defs[it->GetSrcArgIndex()]; - if (subgraph_outputs.count(output_def) == 0) { - subgraph_outputs.insert(output_def); - ordered_subgraph_outputs.insert(output_def); - } - } } } // Assign inputs and outputs to subgraph's meta_def. uint64_t model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + const auto meta_def_old = sub_graph->GetMetaDef(); auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); meta_def->domain = kMSDomain; @@ -181,20 +135,24 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view } } - for (const auto& input : ordered_subgraph_inputs) { - meta_def->inputs.push_back(input->Name()); + for (const auto& input : meta_def_old->inputs) { + meta_def->inputs.push_back(input); } - for (const auto& output : ordered_subgraph_outputs) { - meta_def->outputs.push_back(output->Name()); + for (const auto& output : meta_def_old->outputs) { + meta_def->outputs.push_back(output); } sub_graph->SetMetaDef(std::move(meta_def)); - - result.push_back(std::make_unique(std::move(sub_graph))); } - auto num_of_partitions = result.size(); + const auto num_of_partitions = result.size(); + const auto num_of_supported_nodes = std::accumulate( + result.begin(), result.end(), size_t{0}, + [](const auto& acc, const auto& partition) -> size_t { + return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0); + }); + const auto summary_msg = MakeString( "WebNNExecutionProvider::GetCapability,", " number of partitions supported by WebNN: ", num_of_partitions,