From ea43671eb6a55702a2faa3bb3d8eee4c48cac62b Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 15 Jul 2023 05:36:15 +0800 Subject: [PATCH] [WebNN EP] Support several activation ops (#16693) Support Elu, HardSigmoid, HardSwish, Softplus, Softsign, Tanh. --- .../core/providers/webnn/builders/helper.h | 58 ++++++++------- .../builders/impl/activation_op_builder.cc | 73 ++++++++++++------- .../providers/webnn/builders/model_builder.cc | 42 +++++++---- .../webnn/builders/op_builder_factory.cc | 8 +- 4 files changed, 112 insertions(+), 69 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index b4b7cb175a..0a8fb6bf1d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -93,55 +93,61 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const logging::Logger& logger); static const InlinedHashMap op_map = { {"Abs", "abs"}, + {"Add", "add"}, {"ArgMax", "argMax"}, {"ArgMin", "argMin"}, - {"Add", "add"}, - {"Sub", "sub"}, - {"Mul", "mul"}, - {"Div", "div"}, - {"Pow", "pow"}, + {"AveragePool", "averagePool2d"}, + {"Cast", "cast"}, {"Ceil", "ceil"}, + {"Clip", "clamp"}, + {"Concat", "concat"}, + {"Conv", "conv2d"}, + {"ConvTranspose", "convTranspose2d"}, {"Cos", "cos"}, + {"Div", "div"}, + {"Elu", "elu"}, {"Equal", "equal"}, {"Erf", "erf"}, {"Exp", "exp"}, - {"Neg", "neg"}, - {"Not", "logicalNot"}, - {"Floor", "floor"}, - {"Flatten", "flattenTo2d"}, - {"Identity", "identity"}, - {"Reciprocal", "reciprocal"}, - {"Sin", "sin"}, - {"Sqrt", "sqrt"}, - {"Tan", "tan"}, - {"Relu", "relu"}, - {"LeakyRelu", "leakyRelu"}, - {"Sigmoid", "sigmoid"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Cast", "cast"}, - {"Clip", "clamp"}, - {"Conv", "conv2d"}, - {"ConvTranspose", "convTranspose2d"}, - {"Concat", "concat"}, {"Expand", "expand"}, + {"Flatten", "flattenTo2d"}, + {"Floor", "floor"}, {"Gather", "gather"}, {"Gemm", "gemm"}, - {"MatMul", "matmul"}, {"GlobalAveragePool", "averagePool2d"}, {"GlobalMaxPool", "maxPool2d"}, - {"AveragePool", "averagePool2d"}, {"GroupNormalization", "meanVarianceNormalization"}, + {"HardSigmoid", "hardSigmoid"}, + {"HardSwish", "hardSwish"}, + {"Identity", "identity"}, {"InstanceNormalization", "meanVarianceNormalization"}, {"LayerNormalization", "meanVarianceNormalization"}, + {"LeakyRelu", "leakyRelu"}, + {"MatMul", "matmul"}, {"MaxPool", "maxPool2d"}, + {"Mul", "mul"}, + {"Neg", "neg"}, + {"Not", "logicalNot"}, + {"Pow", "pow"}, + {"Reciprocal", "reciprocal"}, {"ReduceMax", "reduceMax"}, {"ReduceMean", "reduceMean"}, + {"Relu", "relu"}, {"Reshape", "reshape"}, {"Resize", "resample2d"}, {"Shape", "slice"}, + {"Sigmoid", "sigmoid"}, + {"Softplus", "softplus"}, + {"Softsign", "softsign"}, + {"Sin", "sin"}, + {"Slice", "slice"}, + {"Softmax", "softmax"}, {"Split", "split"}, + {"Sqrt", "sqrt"}, {"Squeeze", "squeeze"}, + {"Sub", "sub"}, + {"Tan", "tan"}, + {"Tanh", "tanh"}, {"Transpose", "transpose"}, {"Unsqueeze", "unsqueeze"}, {"Where", "elementwiseIf"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 0dcbf9b527..cb2ff135a7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -31,33 +31,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& op_type(node.OpType()); emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val output = emscripten::val::object(); - if (op_type == "Relu") { - if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { - LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node.Name() << "] fused"; - output = input; - } else { - output = model_builder.GetBuilder().call("relu", input); - } - } else if (op_type == "LeakyRelu") { - if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { - LOGS_DEFAULT(VERBOSE) << "LeakyRelu Node [" << node.Name() << "] fused"; - output = input; - } else { - NodeAttrHelper helper(node); - emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", (float)0.0)); - output = model_builder.GetBuilder().call("leakyRelu", input, options); - } - } else if (op_type == "Sigmoid") { - if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { - LOGS_DEFAULT(VERBOSE) << "Sigmoid Node [" << node.Name() << "] fused"; - output = input; - } else { - output = model_builder.GetBuilder().call("sigmoid", input); - } + + if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) { + LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused"; + output = input; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); + if (op_type == "Elu") { + options.set("alpha", helper.Get("alpha", 1.0f)); + output = model_builder.GetBuilder().call("elu", input, options); + } else if (op_type == "HardSigmoid") { + options.set("alpha", helper.Get("alpha", 0.2f)); + options.set("beta", helper.Get("beta", 0.5f)); + output = model_builder.GetBuilder().call("hardSigmoid", input, options); + } else if (op_type == "HardSwish") { + output = model_builder.GetBuilder().call("hardSwish", input); + } else if (op_type == "LeakyRelu") { + options.set("alpha", helper.Get("alpha", 0.0f)); + output = model_builder.GetBuilder().call("leakyRelu", input, options); + } else if (op_type == "Relu") { + output = model_builder.GetBuilder().call("relu", input); + } else if (op_type == "Sigmoid") { + output = model_builder.GetBuilder().call("sigmoid", input); + } else if (op_type == "Softplus") { + output = model_builder.GetBuilder().call("softplus", input); + } else if (op_type == "Softsign") { + output = model_builder.GetBuilder().call("softsign", input); + } else if (op_type == "Tanh") { + output = model_builder.GetBuilder().call("tanh", input); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -67,7 +73,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related. int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { - // All ops opset 5- uses consumed_inputs attribute which is not supported for now. + // Any operators < opset 6 used the deprecated "consumed_inputs attribute" which is unsupported. return 6; } @@ -75,7 +81,18 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) return; - static std::vector op_types = {"Relu", "LeakyRelu", "Sigmoid"}; + static std::vector op_types = + { + "Elu", + "HardSigmoid", + "HardSwish", + "LeakyRelu", + "Relu", + "Sigmoid", + "Softplus", + "Softsign", + "Tanh", + }; op_registrations.builders.push_back(std::make_unique()); for (const auto& type : op_types) { diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 098193bc63..14ca4f1a1e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -57,33 +57,47 @@ void ModelBuilder::PreprocessInitializers() { } } -emscripten::val GetClampOperator( - const emscripten::val& builder, float min_value, float max_value) { - emscripten::val options = emscripten::val::object(); - options.set("minValue", min_value); - options.set("maxValue", max_value); - return builder.call("clamp", options); -} - void ModelBuilder::PreprocessActivations() { const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { const auto* node(graph_viewer_.GetNode(node_indices[i])); const auto& op_type(node->OpType()); - if (op_type == "Relu") { + if (op_type == "Clip") { + float minValue, maxValue; + GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); + emscripten::val options = emscripten::val::object(); + options.set("minValue", minValue); + options.set("maxValue", maxValue); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); + } else if (op_type == "Elu") { + NodeAttrHelper helper(*node); + emscripten::val options = emscripten::val::object(); + options.set("alpha", helper.Get("alpha", 1.0f)); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("elu", options)); + } else if (op_type == "HardSigmoid") { + NodeAttrHelper helper(*node); + emscripten::val options = emscripten::val::object(); + options.set("alpha", helper.Get("alpha", 0.2f)); + options.set("beta", helper.Get("beta", 0.5f)); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("hardSigmoid", options)); + } else if (op_type == "HardSwish") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("hardSwish")); + } else if (op_type == "Relu") { activation_nodes_.emplace(node->Index(), wnn_builder_.call("relu")); } else if (op_type == "LeakyRelu") { NodeAttrHelper helper(*node); emscripten::val options = emscripten::val::object(); - options.set("alpha", helper.Get("alpha", (float)0.0)); + options.set("alpha", helper.Get("alpha", 0.0f)); activation_nodes_.emplace(node->Index(), wnn_builder_.call("leakyRelu", options)); } else if (op_type == "Sigmoid") { activation_nodes_.emplace(node->Index(), wnn_builder_.call("sigmoid")); - } else if (op_type == "Clip") { - float minValue, maxValue; - GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); - activation_nodes_.emplace(node->Index(), GetClampOperator(wnn_builder_, minValue, maxValue)); + } else if (op_type == "Softplus") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("softplus")); + } else if (op_type == "Softsign") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("softsign")); + } else if (op_type == "Tanh") { + activation_nodes_.emplace(node->Index(), wnn_builder_.call("tanh")); } } } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index b334fdc638..416c9e1bf9 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -44,9 +44,15 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { } { // Activations - CreateActivationOpBuilder("Relu", op_registrations); + CreateActivationOpBuilder("Elu", op_registrations); + CreateActivationOpBuilder("HardSigmoid", op_registrations); + CreateActivationOpBuilder("HardSwish", op_registrations); CreateActivationOpBuilder("LeakyRelu", op_registrations); + CreateActivationOpBuilder("Relu", op_registrations); CreateActivationOpBuilder("Sigmoid", op_registrations); + CreateActivationOpBuilder("Softplus", op_registrations); + CreateActivationOpBuilder("Softsign", op_registrations); + CreateActivationOpBuilder("Tanh", op_registrations); } { // ArgMax/ArgMin