From 64b22cd00f3af5c07552188773f4ec159f68044b Mon Sep 17 00:00:00 2001 From: zesongw Date: Wed, 21 Jun 2023 07:45:40 +0800 Subject: [PATCH] [WebNN EP] Support Where Op (#16380) ### Description Add Where Op for WebNN EP as ternary conditional operator. --------- Co-authored-by: Dwayne Robinson --- .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/ternary_op_builder.cc | 61 +++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 4 ++ .../webnn/builders/op_builder_factory.h | 1 + 4 files changed, 67 insertions(+) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 728317f595..7b11f852b8 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -140,6 +140,7 @@ static const InlinedHashMap op_map = { {"Squeeze", "squeeze"}, {"Transpose", "transpose"}, {"Unsqueeze", "unsqueeze"}, + {"Where", "elementwiseIf"}, }; inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc new file mode 100644 index 0000000000..b5e5600af0 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" +#include "core/providers/webnn/builders/helper.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class TernaryOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; +}; + +// Add operator related. + +Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type(node.OpType()); + ORT_RETURN_IF(node.InputDefs().size() < 3, "Operator requires at least three inputs"); + + emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val input2 = model_builder.GetOperand(node.InputDefs()[2]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "Where") { + output = model_builder.GetBuilder().call("elementwiseIf", input0, input1, input2); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TernaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "Where", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 77bc21561a..0a84fe07b8 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -37,6 +37,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateBinaryOpBuilder("Pow", op_registrations); } + { // Ternary + CreateTernaryOpBuilder("Where", op_registrations); + } + { // Activations CreateActivationOpBuilder("Relu", op_registrations); CreateActivationOpBuilder("LeakyRelu", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index e51d212e51..2dde66dff4 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -41,6 +41,7 @@ void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);