[WebNN EP] Create MLGraphBuilder for every model builder (#21514)

Currently the WebNN spec only allows MLGraphBuilder.build() to be called
once, so we need to create a new builder for every subgraph in the WebNN EP.

Spec change: https://github.com/webmachinelearning/webnn/pull/717
This commit is contained in:
Wanming Lin 2024-08-02 00:15:31 +08:00 committed by GitHub
parent 3b73ef2bf7
commit 8c2ee7b32e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 28 additions and 28 deletions

View file

@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
}
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder_,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;
@ -103,7 +103,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const auto* node(graph_viewer.GetNode(node_idx));
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) {
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
}

View file

@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder_,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger);
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
@ -241,14 +241,14 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Where", {"where", true}},
};
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_,
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
const WebnnDeviceType device_type) {
// Returns false if the op_type is not listed in the op_map.
if (op_map.find(op_type) == op_map.end()) {
return false;
}
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (!wnn_builder_[op_map.find(op_type)->second.opName].as<bool>()) {
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
return false;
}
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather

View file

@ -20,14 +20,20 @@ namespace onnxruntime {
namespace webnn {
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const emscripten::val& builder,
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type)
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type)
: graph_viewer_(graph_viewer),
logger_(logger),
wnn_context_(context),
wnn_builder_(builder),
preferred_layout_(preferred_layout),
wnn_device_type_(wnn_device_type) {}
wnn_device_type_(wnn_device_type) {
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
// is only allowed to be called once.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
}
Status ModelBuilder::Initialize() {
PreprocessInitializers();
@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
if (!wnn_graph.as<bool>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
}
// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();
model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
model->SetInputs(std::move(input_names_));
model->SetOutputs(std::move(output_names_));

View file

@ -22,8 +22,8 @@ class IOpBuilder;
class ModelBuilder {
public:
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const emscripten::val& builder,
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type);
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type);
~ModelBuilder() = default;
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
@ -62,8 +62,8 @@ class ModelBuilder {
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;
emscripten::val wnn_context_ = emscripten::val::object();
emscripten::val wnn_builder_ = emscripten::val::object();
emscripten::val wnn_context_ = emscripten::val::undefined();
emscripten::val wnn_builder_ = emscripten::val::undefined();
DataLayout preferred_layout_;
WebnnDeviceType wnn_device_type_;
InlinedHashMap<std::string, emscripten::val> wnn_operands_;

View file

@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
}
WebNNExecutionProvider::~WebNNExecutionProvider() {}
@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
const auto& logger = *GetLogger();
if (!wnn_builder_.as<bool>()) {
// The GetCapability function may be called again after Compile due to the logic in the
// PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc).
// We need to re-create the wnn_builder_ here to avoid it's been released in last Compile.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger);
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
wnn_builder = emscripten::val::undefined();
if (node_groups.empty()) {
return result;
@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
wnn_builder_, preferred_layout_, wnn_device_type_);
preferred_layout_, wnn_device_type_);
std::unique_ptr<webnn::Model> model;
ORT_RETURN_IF_ERROR(builder.Compile(model));
// Build map from input name to its index in input definitions.
{
InlinedHashMap<std::string, size_t> input_map;
@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
node_compute_funcs.push_back(compute_info);
}
// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();
return Status::OK();
}

View file

@ -43,7 +43,6 @@ class WebNNExecutionProvider : public IExecutionProvider {
private:
emscripten::val wnn_context_ = emscripten::val::undefined();
mutable emscripten::val wnn_builder_ = emscripten::val::undefined();
DataLayout preferred_layout_;
webnn::WebnnDeviceType wnn_device_type_;