From f5bfbd6d81fbe0ba652182f8c0f7da3d6816f7f8 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 16 May 2024 07:49:07 +0800 Subject: [PATCH] [WebNN EP] Remove activation fusion (#20635) The WebNN spec has removed the activation option for conv and batchNormalization. We don't need additional activation fusion in the WebNN EP anymore. [edit by fdwr] Note this is handled in the browser now, which knows more about the backend platform version and can more safely make decisions about which fusions are possible (e.g. for the DirectML backend, whether softmax and gelu can fuse successfully with their base operator). --- .../builders/impl/activation_op_builder.cc | 61 +++++----- .../webnn/builders/impl/clip_op_builder.cc | 8 +- .../webnn/builders/impl/conv_op_builder.cc | 5 - .../builders/impl/normalization_op_builder.cc | 5 +- .../providers/webnn/builders/model_builder.cc | 112 ------------------ .../providers/webnn/builders/model_builder.h | 14 --- 6 files changed, 30 insertions(+), 175 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 6dacd749b8..cd6a95a2aa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -32,40 +32,35 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val output = emscripten::val::object(); - if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { - LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused"; - output = input; + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); + if (op_type == "Elu") { + options.set("alpha", helper.Get("alpha", 1.0f)); + output = model_builder.GetBuilder().call("elu", input, options); + } else 
if (op_type == "Gelu") { + output = model_builder.GetBuilder().call("gelu", input, options); + } else if (op_type == "HardSigmoid") { + options.set("alpha", helper.Get("alpha", 0.2f)); + options.set("beta", helper.Get("beta", 0.5f)); + output = model_builder.GetBuilder().call("hardSigmoid", input, options); + } else if (op_type == "HardSwish") { + output = model_builder.GetBuilder().call("hardSwish", input); + } else if (op_type == "LeakyRelu") { + options.set("alpha", helper.Get("alpha", 0.0f)); + output = model_builder.GetBuilder().call("leakyRelu", input, options); + } else if (op_type == "Relu") { + output = model_builder.GetBuilder().call("relu", input); + } else if (op_type == "Sigmoid") { + output = model_builder.GetBuilder().call("sigmoid", input); + } else if (op_type == "Softplus") { + output = model_builder.GetBuilder().call("softplus", input); + } else if (op_type == "Softsign") { + output = model_builder.GetBuilder().call("softsign", input); + } else if (op_type == "Tanh") { + output = model_builder.GetBuilder().call("tanh", input); } else { - NodeAttrHelper helper(node); - emscripten::val options = emscripten::val::object(); - if (op_type == "Elu") { - options.set("alpha", helper.Get("alpha", 1.0f)); - output = model_builder.GetBuilder().call("elu", input, options); - } else if (op_type == "Gelu") { - output = model_builder.GetBuilder().call("gelu", input, options); - } else if (op_type == "HardSigmoid") { - options.set("alpha", helper.Get("alpha", 0.2f)); - options.set("beta", helper.Get("beta", 0.5f)); - output = model_builder.GetBuilder().call("hardSigmoid", input, options); - } else if (op_type == "HardSwish") { - output = model_builder.GetBuilder().call("hardSwish", input); - } else if (op_type == "LeakyRelu") { - options.set("alpha", helper.Get("alpha", 0.0f)); - output = model_builder.GetBuilder().call("leakyRelu", input, options); - } else if (op_type == "Relu") { - output = model_builder.GetBuilder().call("relu", input); - } else if (op_type 
== "Sigmoid") { - output = model_builder.GetBuilder().call("sigmoid", input); - } else if (op_type == "Softplus") { - output = model_builder.GetBuilder().call("softplus", input); - } else if (op_type == "Softsign") { - output = model_builder.GetBuilder().call("softsign", input); - } else if (op_type == "Tanh") { - output = model_builder.GetBuilder().call("tanh", input); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index 0d6001bcba..84a51b6679 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -52,13 +52,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("minValue", minValue); options.set("maxValue", maxValue); emscripten::val input = model_builder.GetOperand(input_name); - emscripten::val output = emscripten::val::object(); - if (Contains(model_builder.GetFusedActivations(), input_name)) { - LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused"; - output = input; - } else { - output = model_builder.GetBuilder().call("clamp", input, options); - } + emscripten::val output = model_builder.GetBuilder().call("clamp", input, options); model_builder.AddOperand(output_name, std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 6b232a58aa..f79920b63a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ 
b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -138,11 +138,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); } - emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]); - if (emscripten::val::null() != activation) { - options.set("activation", activation); - } - return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 3601d3b2b3..b6948b48c3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -79,10 +79,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { options.set("axis", rank - 1); } - emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]); - if (emscripten::val::null() != activation) { - options.set("activation", activation); - } + output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); } else if (op_type == "LayerNormalization") { int64_t axis = helper.Get("axis", -1); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index d9acfa1f98..ff3e1a7179 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -31,7 +31,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge Status ModelBuilder::Initialize() { PreprocessInitializers(); - PreprocessActivations(); ORT_RETURN_IF_ERROR(RegisterInitializers()); ORT_RETURN_IF_ERROR(RegisterModelInputs()); ORT_RETURN_IF_ERROR(AddOperations()); @@ -78,79 +77,6 @@ void 
ModelBuilder::PreprocessInitializers() { } } -void ModelBuilder::PreprocessActivations() { - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - - if (wnn_device_type_ == WebnnDeviceType::CPU) { - // WebNN CPU currently only supports "Relu" and "Clip" fusion. - supported_activation_nodes_ = {"Clip", "Relu"}; - } else { - supported_activation_nodes_ = { - // Temporarily disable clamp fusion for WebNN GPU as which is not supported yet. - // "Clip", - "Elu", - "Gelu", - "HardSigmoid", - "HardSwish", - "Relu", - "LeakyRelu", - "Sigmoid", - "Softplus", - "Softsign", - "Tanh", - }; - } - - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - const auto& op_type(node->OpType()); - - // Ignore unsupported activation nodes. - if (!Contains(supported_activation_nodes_, op_type)) { - continue; - } - - if (op_type == "Clip") { - float minValue, maxValue; - GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); - emscripten::val options = emscripten::val::object(); - options.set("minValue", minValue); - options.set("maxValue", maxValue); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); - } else if (op_type == "Elu") { - NodeAttrHelper helper(*node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", 1.0f)); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("elu", options)); - } else if (op_type == "Gelu") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("gelu")); - } else if (op_type == "HardSigmoid") { - NodeAttrHelper helper(*node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", 0.2f)); - options.set("beta", helper.Get("beta", 0.5f)); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("hardSigmoid", options)); - } else if (op_type == "HardSwish") { - activation_nodes_.emplace(node->Index(), 
wnn_builder_.call("hardSwish")); - } else if (op_type == "Relu") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("relu")); - } else if (op_type == "LeakyRelu") { - NodeAttrHelper helper(*node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", 0.0f)); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("leakyRelu", options)); - } else if (op_type == "Sigmoid") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("sigmoid")); - } else if (op_type == "Softplus") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("softplus")); - } else if (op_type == "Softsign") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("softsign")); - } else if (op_type == "Tanh") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("tanh")); - } - } -} - Status ModelBuilder::RegisterInitializers() { for (const auto& pair : GetInitializerTensors()) { const auto& tensor = *pair.second; @@ -421,44 +347,6 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { return Status::OK(); } -emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) { - emscripten::val fused_op = emscripten::val::null(); - for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { - const auto& dst_node = it->GetNode(); - const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; - if (!Contains(supported_activation_nodes_, dst_node.OpType())) { - return emscripten::val::null(); - } - if (Contains(activation_nodes_, dst_node.Index())) { - if (&output == dst_input) { - fused_op = activation_nodes_.at(dst_node.Index()); - } - } else { - // If there is any other non-relu node using the output - // will add relu separately. - if (&output == dst_input) { - return emscripten::val::null(); - } - } - } - - // If output is a graph output, will add relu separately. 
- if (fused_op != emscripten::val::null()) { - for (const auto* graph_output : graph_viewer_.GetOutputs()) { - if (&output == graph_output) { - return emscripten::val::null(); - } - } - - LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType() - << "], fused the output [" << output.Name() << "]"; - - fused_activations_.insert(output.Name()); - } - - return fused_op; -} - void ModelBuilder::AddScalarOutput(const std::string& output_name) { scalar_outputs_.insert(output_name); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2120309980..8c1848eb83 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -44,11 +44,6 @@ class ModelBuilder { Status AddOperandFromPersistMemoryBuffer( const std::string& name, const void* buffer, const size_t size, const std::vector shape, const int32_t data_type); - // Find if an output has a fuseable activation (e.g., Relu). - emscripten::val FindActivation(const Node& node, const NodeArg& output); - - const InlinedHashSet& - GetFusedActivations() const { return fused_activations_; } DataLayout GetPreferredLayout() const { return preferred_layout_; } @@ -82,22 +77,13 @@ class ModelBuilder { InlinedHashSet skipped_initializers_; InlinedHashSet skipped_inputs_; - InlinedHashSet fused_activations_; - - InlinedHashSet supported_activation_nodes_; - uint32_t name_token_{0}; InlinedHashSet unique_names_; - // All activation nodes (e.g., Relu) as a map . - InlinedHashMap activation_nodes_; - // Convert the onnx model to WebNN operands Status Initialize() ORT_MUST_USE_RESULT; void PreprocessInitializers(); - // Preprocess all the activation nodes (e.g., Relu) for easy query later. - void PreprocessActivations(); // Copy and process all the initializers to WebNN constants. Status RegisterInitializers() ORT_MUST_USE_RESULT;