mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Create MLGraphBuilder for every model builder (#21514)
Currently WebNN spec only allows MLGraphBuilder.build() to be called once, we need to create new builder for every subgraph in WebNN EP. Spec change: https://github.com/webmachinelearning/webnn/pull/717
This commit is contained in:
parent
3b73ef2bf7
commit
8c2ee7b32e
6 changed files with 28 additions and 28 deletions
|
|
@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
|
|||
}
|
||||
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder_,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<std::vector<size_t>> supported_node_groups;
|
||||
|
|
@ -103,7 +103,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
const auto* node(graph_viewer.GetNode(node_idx));
|
||||
bool supported = false;
|
||||
// Firstly check if platform supports the WebNN op.
|
||||
if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) {
|
||||
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
|
||||
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
|
||||
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
|
|||
|
||||
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
const emscripten::val& wnn_builder_,
|
||||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger);
|
||||
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
|
||||
|
|
@ -241,14 +241,14 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
|
|||
{"Where", {"where", true}},
|
||||
};
|
||||
|
||||
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_,
|
||||
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type) {
|
||||
// Returns false if the op_type is not listed in the op_map.
|
||||
if (op_map.find(op_type) == op_map.end()) {
|
||||
return false;
|
||||
}
|
||||
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
|
||||
if (!wnn_builder_[op_map.find(op_type)->second.opName].as<bool>()) {
|
||||
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
|
||||
return false;
|
||||
}
|
||||
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather
|
||||
|
|
|
|||
|
|
@ -20,14 +20,20 @@ namespace onnxruntime {
|
|||
namespace webnn {
|
||||
|
||||
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
|
||||
const emscripten::val& context, const emscripten::val& builder,
|
||||
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type)
|
||||
const emscripten::val& context, const DataLayout preferred_layout,
|
||||
const WebnnDeviceType wnn_device_type)
|
||||
: graph_viewer_(graph_viewer),
|
||||
logger_(logger),
|
||||
wnn_context_(context),
|
||||
wnn_builder_(builder),
|
||||
preferred_layout_(preferred_layout),
|
||||
wnn_device_type_(wnn_device_type) {}
|
||||
wnn_device_type_(wnn_device_type) {
|
||||
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
|
||||
// is only allowed to be called once.
|
||||
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
|
||||
if (!wnn_builder_.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN builder.");
|
||||
}
|
||||
}
|
||||
|
||||
Status ModelBuilder::Initialize() {
|
||||
PreprocessInitializers();
|
||||
|
|
@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
|
|||
if (!wnn_graph.as<bool>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
|
||||
}
|
||||
// Explicitly release the WebNN builder to free memory.
|
||||
wnn_builder_ = emscripten::val::undefined();
|
||||
model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
|
||||
model->SetInputs(std::move(input_names_));
|
||||
model->SetOutputs(std::move(output_names_));
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ class IOpBuilder;
|
|||
class ModelBuilder {
|
||||
public:
|
||||
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
|
||||
const emscripten::val& context, const emscripten::val& builder,
|
||||
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type);
|
||||
const emscripten::val& context, const DataLayout preferred_layout,
|
||||
const WebnnDeviceType wnn_device_type);
|
||||
~ModelBuilder() = default;
|
||||
|
||||
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
|
||||
|
|
@ -62,8 +62,8 @@ class ModelBuilder {
|
|||
const GraphViewer& graph_viewer_;
|
||||
const logging::Logger& logger_;
|
||||
|
||||
emscripten::val wnn_context_ = emscripten::val::object();
|
||||
emscripten::val wnn_builder_ = emscripten::val::object();
|
||||
emscripten::val wnn_context_ = emscripten::val::undefined();
|
||||
emscripten::val wnn_builder_ = emscripten::val::undefined();
|
||||
DataLayout preferred_layout_;
|
||||
WebnnDeviceType wnn_device_type_;
|
||||
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
|
||||
|
|
|
|||
|
|
@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
|
|||
if (!wnn_context_.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN context.");
|
||||
}
|
||||
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
|
||||
if (!wnn_builder_.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN builder.");
|
||||
}
|
||||
}
|
||||
|
||||
WebNNExecutionProvider::~WebNNExecutionProvider() {}
|
||||
|
|
@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
|
|||
|
||||
const auto& logger = *GetLogger();
|
||||
|
||||
if (!wnn_builder_.as<bool>()) {
|
||||
// The GetCapability function may be called again after Compile due to the logic in the
|
||||
// PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc).
|
||||
// We need to re-create the wnn_builder_ here to avoid it's been released in last Compile.
|
||||
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
|
||||
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
|
||||
if (!wnn_builder.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN builder.");
|
||||
}
|
||||
|
||||
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger);
|
||||
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
|
||||
wnn_builder = emscripten::val::undefined();
|
||||
|
||||
if (node_groups.empty()) {
|
||||
return result;
|
||||
|
|
@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
|
|||
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
|
||||
|
||||
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
|
||||
wnn_builder_, preferred_layout_, wnn_device_type_);
|
||||
preferred_layout_, wnn_device_type_);
|
||||
std::unique_ptr<webnn::Model> model;
|
||||
ORT_RETURN_IF_ERROR(builder.Compile(model));
|
||||
|
||||
// Build map from input name to its index in input definitions.
|
||||
{
|
||||
InlinedHashMap<std::string, size_t> input_map;
|
||||
|
|
@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
|
|||
node_compute_funcs.push_back(compute_info);
|
||||
}
|
||||
|
||||
// Explicitly release the WebNN builder to free memory.
|
||||
wnn_builder_ = emscripten::val::undefined();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ class WebNNExecutionProvider : public IExecutionProvider {
|
|||
|
||||
private:
|
||||
emscripten::val wnn_context_ = emscripten::val::undefined();
|
||||
mutable emscripten::val wnn_builder_ = emscripten::val::undefined();
|
||||
|
||||
DataLayout preferred_layout_;
|
||||
webnn::WebnnDeviceType wnn_device_type_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue