mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
[WebNN EP] Optimize model partitioning (#23332)
### Description
<!-- Describe your changes. -->
The old `GetCapability` function of WebNN EP is just a very simple
search for groups of nodes that can be handled. This doesn't work well
in the following example graph, where A and D could be handled by the
EP, but B is between them in the topological order, as you get two
single node capabilities. However, it may also be advantageous if C and
E could be handled by the EP, since they would be combined with D even
though they are not connected.
```
A B C
| / |
D E
| |
```
Therefore, we improve partitioning results by reusing
`utils::CreateSupportedPartitions`, which walks the edges for each node
that the EP can handle as they are iterated in topological order. This
would guarantee that all connected nodes that can be handled are grouped
together. Correspondingly, we modify the `webnn::GetSupportedNodes`
function to return the supported nodes instead of the group of supported
partitions.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Co-authored-by: Dwayne Robinson <fdwr@hotmail.com>
This commit is contained in:
parent
5735e1bce0
commit
80f686e055
3 changed files with 60 additions and 116 deletions
|
|
@ -99,44 +99,30 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<std::vector<size_t>> supported_node_groups;
|
||||
std::vector<size_t> supported_node_group;
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) {
|
||||
std::unordered_set<const Node*> supported_nodes;
|
||||
|
||||
for (size_t i = 0; i < node_indices.size(); i++) {
|
||||
auto node_idx = node_indices[i];
|
||||
const auto* node(graph_viewer.GetNode(node_idx));
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
bool supported = false;
|
||||
// Firstly check if platform supports the WebNN op.
|
||||
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
|
||||
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
|
||||
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
|
||||
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
|
||||
}
|
||||
|
||||
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
|
||||
<< "] index: [" << node_idx
|
||||
<< "] name: [" << node->Name()
|
||||
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
|
||||
<< "] index: [" << node.Index()
|
||||
<< "] name: [" << node.Name()
|
||||
<< "] supported: [" << supported
|
||||
<< "]";
|
||||
if (supported) {
|
||||
supported_node_group.push_back(node_idx);
|
||||
} else {
|
||||
if (!supported_node_group.empty()) {
|
||||
supported_node_groups.push_back(supported_node_group);
|
||||
supported_node_group.clear();
|
||||
}
|
||||
supported_nodes.insert(&node);
|
||||
}
|
||||
}
|
||||
|
||||
if (!supported_node_group.empty()) {
|
||||
supported_node_groups.push_back(supported_node_group);
|
||||
}
|
||||
|
||||
return supported_node_groups;
|
||||
return supported_nodes;
|
||||
}
|
||||
|
||||
bool AreInputDataTypesSame(const std::string& op_type,
|
||||
|
|
|
|||
|
|
@ -188,12 +188,12 @@ inline bool TensorExists(const ConstPointerContainer<std::vector<NodeArg*>>& def
|
|||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
|
||||
const logging::Logger& logger, bool allow_empty_input = false);
|
||||
|
||||
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger);
|
||||
// Get a set of nodes supported by WebNN EP.
|
||||
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger);
|
||||
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
|
||||
// we need to check the support of the decomposed ops.
|
||||
static const InlinedHashMap<std::string, std::string> op_map = {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@
|
|||
#include "core/common/safeint.h"
|
||||
#include "core/providers/webnn/allocator.h"
|
||||
#include "core/providers/webnn/data_transfer.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
|
||||
#include "builders/model.h"
|
||||
#include "builders/helper.h"
|
||||
|
|
@ -20,6 +23,8 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr const char* WEBNN = "WEBNN";
|
||||
|
||||
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
|
||||
: IExecutionProvider{
|
||||
onnxruntime::kWebNNExecutionProvider,
|
||||
|
|
@ -51,8 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {}
|
|||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_registries*/) const {
|
||||
std::vector<std::unique_ptr<ComputeCapability>> result;
|
||||
|
||||
// For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its
|
||||
// ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for
|
||||
// identifying the required initializer names and storing into 'meta_def->constant_initializers'.
|
||||
|
|
@ -64,23 +67,6 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
all_initializers = webnn::CollectAllInitializedTensors(graph_viewer);
|
||||
}
|
||||
|
||||
/*
|
||||
Very basic search for groups of nodes that can be handled by the EP.
|
||||
This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP
|
||||
but B is between them in the topological sort as you'll get two single node capabilities. However if can also
|
||||
be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected.
|
||||
Not sure how often each of these scenarios happens.
|
||||
|
||||
A B C
|
||||
| / |
|
||||
D E
|
||||
| |
|
||||
|
||||
Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order,
|
||||
accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all
|
||||
connected nodes that can be handled are grouped together.
|
||||
*/
|
||||
|
||||
const auto& logger = *GetLogger();
|
||||
|
||||
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
|
||||
|
|
@ -88,43 +74,37 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
ORT_THROW("Failed to create WebNN builder.");
|
||||
}
|
||||
|
||||
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
|
||||
wnn_builder = emscripten::val::undefined();
|
||||
// Get all the NodeUnits in the graph_viewer
|
||||
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
|
||||
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
|
||||
|
||||
if (node_groups.empty()) {
|
||||
return result;
|
||||
}
|
||||
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
|
||||
|
||||
const auto supported_nodes = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
|
||||
|
||||
const auto gen_metadef_name = [&]() {
|
||||
HashValue model_hash;
|
||||
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
|
||||
return MakeString(WEBNN, "_", model_hash, "_", metadef_id);
|
||||
};
|
||||
|
||||
auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
|
||||
gen_metadef_name, WEBNN, kWebNNExecutionProvider,
|
||||
&node_unit_map, /*drop_constant_initializers*/ true);
|
||||
|
||||
// Release wnn_builder
|
||||
wnn_builder = emscripten::val::undefined();
|
||||
|
||||
const auto& graph_output_list = graph_viewer.GetOutputs();
|
||||
InlinedHashSet<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());
|
||||
|
||||
size_t num_of_supported_nodes = 0;
|
||||
for (const auto& group : node_groups) {
|
||||
if (group.empty())
|
||||
for (auto& capability : result) {
|
||||
auto& sub_graph = capability->sub_graph;
|
||||
if (sub_graph->nodes.empty())
|
||||
continue;
|
||||
|
||||
num_of_supported_nodes += group.size();
|
||||
LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: "
|
||||
<< group.size();
|
||||
|
||||
InlinedHashSet<NodeIndex> node_set;
|
||||
node_set.reserve(group.size());
|
||||
for (const auto& index : group) {
|
||||
node_set.insert(index);
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
|
||||
|
||||
std::vector<std::string> subgraph_initializers;
|
||||
InlinedHashSet<const NodeArg*> node_outputs;
|
||||
InlinedHashSet<const NodeArg*> subgraph_inputs;
|
||||
InlinedHashSet<const NodeArg*> subgraph_outputs;
|
||||
std::vector<const NodeArg*> ordered_subgraph_inputs;
|
||||
// Output should be unique. It may be produced as graph output and subgraph output.
|
||||
InlinedHashSet<const NodeArg*> ordered_subgraph_outputs;
|
||||
|
||||
for (const auto& index : group) {
|
||||
sub_graph->nodes.push_back(index);
|
||||
for (const auto& index : sub_graph->nodes) {
|
||||
const auto* node = graph_viewer.GetNode(index);
|
||||
|
||||
for (const auto* input : node->InputDefs()) {
|
||||
|
|
@ -136,39 +116,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
if (is_subgraph && Contains(all_initializers, input->Name())) {
|
||||
subgraph_initializers.push_back(input->Name());
|
||||
}
|
||||
// If the node input was not produced by this subgraph, add it to the subgraph inputs.
|
||||
if (node_outputs.count(input) == 0) {
|
||||
if (subgraph_inputs.count(input) == 0) {
|
||||
subgraph_inputs.insert(input);
|
||||
ordered_subgraph_inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto& output_defs = node->OutputDefs();
|
||||
for (const auto* output_def : output_defs) {
|
||||
node_outputs.insert(output_def);
|
||||
// if output is overall graph output we need to produce it.
|
||||
if (graph_outputs.count(output_def) != 0) {
|
||||
ordered_subgraph_outputs.insert(output_def);
|
||||
}
|
||||
}
|
||||
|
||||
// if output connects to a node not in this subgraph we need to produce it.
|
||||
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
|
||||
if (node_set.count(it->GetNode().Index()) == 0) {
|
||||
const auto* output_def = output_defs[it->GetSrcArgIndex()];
|
||||
if (subgraph_outputs.count(output_def) == 0) {
|
||||
subgraph_outputs.insert(output_def);
|
||||
ordered_subgraph_outputs.insert(output_def);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assign inputs and outputs to subgraph's meta_def.
|
||||
uint64_t model_hash;
|
||||
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
|
||||
const auto meta_def_old = sub_graph->GetMetaDef();
|
||||
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
|
||||
meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id);
|
||||
meta_def->domain = kMSDomain;
|
||||
|
|
@ -181,20 +135,24 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
}
|
||||
}
|
||||
|
||||
for (const auto& input : ordered_subgraph_inputs) {
|
||||
meta_def->inputs.push_back(input->Name());
|
||||
for (const auto& input : meta_def_old->inputs) {
|
||||
meta_def->inputs.push_back(input);
|
||||
}
|
||||
|
||||
for (const auto& output : ordered_subgraph_outputs) {
|
||||
meta_def->outputs.push_back(output->Name());
|
||||
for (const auto& output : meta_def_old->outputs) {
|
||||
meta_def->outputs.push_back(output);
|
||||
}
|
||||
|
||||
sub_graph->SetMetaDef(std::move(meta_def));
|
||||
|
||||
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
|
||||
}
|
||||
|
||||
auto num_of_partitions = result.size();
|
||||
const auto num_of_partitions = result.size();
|
||||
const auto num_of_supported_nodes = std::accumulate(
|
||||
result.begin(), result.end(), size_t{0},
|
||||
[](const auto& acc, const auto& partition) -> size_t {
|
||||
return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0);
|
||||
});
|
||||
|
||||
const auto summary_msg = MakeString(
|
||||
"WebNNExecutionProvider::GetCapability,",
|
||||
" number of partitions supported by WebNN: ", num_of_partitions,
|
||||
|
|
|
|||
Loading…
Reference in a new issue