From ddd4e8c3e360ffedb267704fc7fbb8a368d76e34 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 26 Apr 2024 08:16:55 -0700 Subject: [PATCH] [WebNN EP] Improve activation fusion (#20320) - Create a common util to get supported activation set - Fuse activation to BatchNormalization if possible --- .../webnn/builders/impl/conv_op_builder.cc | 4 +- .../builders/impl/normalization_op_builder.cc | 4 ++ .../providers/webnn/builders/model_builder.cc | 47 ++++++++++++++----- .../providers/webnn/builders/model_builder.h | 5 +- 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index c74545479e..6b232a58aa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -137,8 +137,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, if (input_defs.size() > 2) { options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); } - InlinedHashSet supported_nodes{"Clip", "Relu"}; - emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0], supported_nodes); + + emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]); if (emscripten::val::null() != activation) { options.set("activation", activation); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 50e04df4fe..3601d3b2b3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -79,6 +79,10 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { options.set("axis", rank - 1); } + emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]); + if (emscripten::val::null() != activation) { + options.set("activation", activation); + } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); } else if (op_type == "LayerNormalization") { int64_t axis = helper.Get("axis", -1); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index cb6669ecba..d9acfa1f98 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -80,20 +80,43 @@ void ModelBuilder::PreprocessInitializers() { void ModelBuilder::PreprocessActivations() { const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + + if (wnn_device_type_ == WebnnDeviceType::CPU) { + // WebNN CPU currently only supports "Relu" and "Clip" fusion. + supported_activation_nodes_ = {"Clip", "Relu"}; + } else { + supported_activation_nodes_ = { + // Temporarily disable clamp fusion for WebNN GPU as which is not supported yet. + // "Clip", + "Elu", + "Gelu", + "HardSigmoid", + "HardSwish", + "Relu", + "LeakyRelu", + "Sigmoid", + "Softplus", + "Softsign", + "Tanh", + }; + } + for (size_t i = 0; i < node_indices.size(); i++) { const auto* node(graph_viewer_.GetNode(node_indices[i])); const auto& op_type(node->OpType()); + // Ignore unsupported activation nodes. + if (!Contains(supported_activation_nodes_, op_type)) { + continue; + } + if (op_type == "Clip") { - // Temporarily disable clamp fusion for WebNN GPU as which is not supported yet. - if (wnn_device_type_ == WebnnDeviceType::CPU) { - float minValue, maxValue; - GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); - emscripten::val options = emscripten::val::object(); - options.set("minValue", minValue); - options.set("maxValue", maxValue); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); - } + float minValue, maxValue; + GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); + emscripten::val options = emscripten::val::object(); + options.set("minValue", minValue); + options.set("maxValue", maxValue); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); } else if (op_type == "Elu") { NodeAttrHelper helper(*node); emscripten::val options = emscripten::val::object(); @@ -398,14 +421,12 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { return Status::OK(); } -// supported_nodes is provided by the op to indicate whether it can be fused with the activation node. -emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output, - const InlinedHashSet supported_nodes) { +emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) { emscripten::val fused_op = emscripten::val::null(); for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { const auto& dst_node = it->GetNode(); const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; - if (!Contains(supported_nodes, dst_node.OpType())) { + if (!Contains(supported_activation_nodes_, dst_node.OpType())) { return emscripten::val::null(); } if (Contains(activation_nodes_, dst_node.Index())) { diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 16cc7a376b..2120309980 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -45,8 +45,7 @@ class ModelBuilder { const std::string& name, const void* buffer, const size_t size, const std::vector shape, const int32_t data_type); // Find if an output has a fuseable activation (e.g., Relu). - emscripten::val FindActivation(const Node& node, const NodeArg& output, - const InlinedHashSet supported_nodes = {}); + emscripten::val FindActivation(const Node& node, const NodeArg& output); const InlinedHashSet& GetFusedActivations() const { return fused_activations_; } @@ -85,6 +84,8 @@ class ModelBuilder { InlinedHashSet fused_activations_; + InlinedHashSet supported_activation_nodes_; + uint32_t name_token_{0}; InlinedHashSet unique_names_;