diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 6dacd749b8..cd6a95a2aa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -32,40 +32,35 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val output = emscripten::val::object(); - if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { - LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused"; - output = input; + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); + if (op_type == "Elu") { + options.set("alpha", helper.Get("alpha", 1.0f)); + output = model_builder.GetBuilder().call("elu", input, options); + } else if (op_type == "Gelu") { + output = model_builder.GetBuilder().call("gelu", input, options); + } else if (op_type == "HardSigmoid") { + options.set("alpha", helper.Get("alpha", 0.2f)); + options.set("beta", helper.Get("beta", 0.5f)); + output = model_builder.GetBuilder().call("hardSigmoid", input, options); + } else if (op_type == "HardSwish") { + output = model_builder.GetBuilder().call("hardSwish", input); + } else if (op_type == "LeakyRelu") { + options.set("alpha", helper.Get("alpha", 0.0f)); + output = model_builder.GetBuilder().call("leakyRelu", input, options); + } else if (op_type == "Relu") { + output = model_builder.GetBuilder().call("relu", input); + } else if (op_type == "Sigmoid") { + output = model_builder.GetBuilder().call("sigmoid", input); + } else if (op_type == "Softplus") { + output = model_builder.GetBuilder().call("softplus", input); + } else if (op_type == "Softsign") { + output = model_builder.GetBuilder().call("softsign", input); + } else if (op_type == "Tanh") { + output = model_builder.GetBuilder().call("tanh", input); } else { - NodeAttrHelper helper(node); - emscripten::val options = emscripten::val::object(); - if (op_type == "Elu") { - options.set("alpha", helper.Get("alpha", 1.0f)); - output = model_builder.GetBuilder().call("elu", input, options); - } else if (op_type == "Gelu") { - output = model_builder.GetBuilder().call("gelu", input, options); - } else if (op_type == "HardSigmoid") { - options.set("alpha", helper.Get("alpha", 0.2f)); - options.set("beta", helper.Get("beta", 0.5f)); - output = model_builder.GetBuilder().call("hardSigmoid", input, options); - } else if (op_type == "HardSwish") { - output = model_builder.GetBuilder().call("hardSwish", input); - } else if (op_type == "LeakyRelu") { - options.set("alpha", helper.Get("alpha", 0.0f)); - output = model_builder.GetBuilder().call("leakyRelu", input, options); - } else if (op_type == "Relu") { - output = model_builder.GetBuilder().call("relu", input); - } else if (op_type == "Sigmoid") { - output = model_builder.GetBuilder().call("sigmoid", input); - } else if (op_type == "Softplus") { - output = model_builder.GetBuilder().call("softplus", input); - } else if (op_type == "Softsign") { - output = model_builder.GetBuilder().call("softsign", input); - } else if (op_type == "Tanh") { - output = model_builder.GetBuilder().call("tanh", input); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index 0d6001bcba..84a51b6679 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -52,13 +52,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("minValue", minValue); options.set("maxValue", maxValue); emscripten::val input = model_builder.GetOperand(input_name); - emscripten::val output = emscripten::val::object(); - if (Contains(model_builder.GetFusedActivations(), input_name)) { - LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused"; - output = input; - } else { - output = model_builder.GetBuilder().call("clamp", input, options); - } + emscripten::val output = model_builder.GetBuilder().call("clamp", input, options); model_builder.AddOperand(output_name, std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 6b232a58aa..f79920b63a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -138,11 +138,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); } - emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]); - if (emscripten::val::null() != activation) { - options.set("activation", activation); - } - return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 3601d3b2b3..b6948b48c3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -79,10 +79,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { options.set("axis", rank - 1); } - emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]); - if (emscripten::val::null() != activation) { - options.set("activation", activation); - } + output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); } else if (op_type == "LayerNormalization") { int64_t axis = helper.Get("axis", -1); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index d9acfa1f98..ff3e1a7179 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -31,7 +31,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge Status ModelBuilder::Initialize() { PreprocessInitializers(); - PreprocessActivations(); ORT_RETURN_IF_ERROR(RegisterInitializers()); ORT_RETURN_IF_ERROR(RegisterModelInputs()); ORT_RETURN_IF_ERROR(AddOperations()); @@ -78,79 +77,6 @@ void ModelBuilder::PreprocessInitializers() { } } -void ModelBuilder::PreprocessActivations() { - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - - if (wnn_device_type_ == WebnnDeviceType::CPU) { - // WebNN CPU currently only supports "Relu" and "Clip" fusion. - supported_activation_nodes_ = {"Clip", "Relu"}; - } else { - supported_activation_nodes_ = { - // Temporarily disable clamp fusion for WebNN GPU as which is not supported yet. - // "Clip", - "Elu", - "Gelu", - "HardSigmoid", - "HardSwish", - "Relu", - "LeakyRelu", - "Sigmoid", - "Softplus", - "Softsign", - "Tanh", - }; - } - - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - const auto& op_type(node->OpType()); - - // Ignore unsupported activation nodes. - if (!Contains(supported_activation_nodes_, op_type)) { - continue; - } - - if (op_type == "Clip") { - float minValue, maxValue; - GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); - emscripten::val options = emscripten::val::object(); - options.set("minValue", minValue); - options.set("maxValue", maxValue); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); - } else if (op_type == "Elu") { - NodeAttrHelper helper(*node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", 1.0f)); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("elu", options)); - } else if (op_type == "Gelu") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("gelu")); - } else if (op_type == "HardSigmoid") { - NodeAttrHelper helper(*node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", 0.2f)); - options.set("beta", helper.Get("beta", 0.5f)); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("hardSigmoid", options)); - } else if (op_type == "HardSwish") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("hardSwish")); - } else if (op_type == "Relu") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("relu")); - } else if (op_type == "LeakyRelu") { - NodeAttrHelper helper(*node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", 0.0f)); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("leakyRelu", options)); - } else if (op_type == "Sigmoid") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("sigmoid")); - } else if (op_type == "Softplus") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("softplus")); - } else if (op_type == "Softsign") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("softsign")); - } else if (op_type == "Tanh") { - activation_nodes_.emplace(node->Index(), wnn_builder_.call("tanh")); - } - } -} - Status ModelBuilder::RegisterInitializers() { for (const auto& pair : GetInitializerTensors()) { const auto& tensor = *pair.second; @@ -421,44 +347,6 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { return Status::OK(); } -emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) { - emscripten::val fused_op = emscripten::val::null(); - for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { - const auto& dst_node = it->GetNode(); - const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; - if (!Contains(supported_activation_nodes_, dst_node.OpType())) { - return emscripten::val::null(); - } - if (Contains(activation_nodes_, dst_node.Index())) { - if (&output == dst_input) { - fused_op = activation_nodes_.at(dst_node.Index()); - } - } else { - // If there is any other non-relu node using the output - // will add relu separately. - if (&output == dst_input) { - return emscripten::val::null(); - } - } - } - - // If output is a graph output, will add relu separately. - if (fused_op != emscripten::val::null()) { - for (const auto* graph_output : graph_viewer_.GetOutputs()) { - if (&output == graph_output) { - return emscripten::val::null(); - } - } - - LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType() - << "], fused the output [" << output.Name() << "]"; - - fused_activations_.insert(output.Name()); - } - - return fused_op; -} - void ModelBuilder::AddScalarOutput(const std::string& output_name) { scalar_outputs_.insert(output_name); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2120309980..8c1848eb83 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -44,11 +44,6 @@ class ModelBuilder { Status AddOperandFromPersistMemoryBuffer( const std::string& name, const void* buffer, const size_t size, const std::vector shape, const int32_t data_type); - // Find if an output has a fuseable activation (e.g., Relu). - emscripten::val FindActivation(const Node& node, const NodeArg& output); - - const InlinedHashSet& - GetFusedActivations() const { return fused_activations_; } DataLayout GetPreferredLayout() const { return preferred_layout_; } @@ -82,22 +77,13 @@ class ModelBuilder { InlinedHashSet skipped_initializers_; InlinedHashSet skipped_inputs_; - InlinedHashSet fused_activations_; - - InlinedHashSet supported_activation_nodes_; - uint32_t name_token_{0}; InlinedHashSet unique_names_; - // All activation nodes (e.g., Relu) as a map . - InlinedHashMap activation_nodes_; - // Convert the onnx model to WebNN operands Status Initialize() ORT_MUST_USE_RESULT; void PreprocessInitializers(); - // Preprocess all the activation nodes (e.g., Relu) for easy query later. - void PreprocessActivations(); // Copy and process all the initializers to WebNN constants. Status RegisterInitializers() ORT_MUST_USE_RESULT;