mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
[WebNN EP] Support Where Op (#16380)
### Description Add Where Op for WebNN EP as ternary conditional operator. --------- Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
parent
e2381c42f2
commit
64b22cd00f
4 changed files with 67 additions and 0 deletions
|
|
@ -140,6 +140,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
|
|||
{"Squeeze", "squeeze"},
|
||||
{"Transpose", "transpose"},
|
||||
{"Unsqueeze", "unsqueeze"},
|
||||
{"Where", "elementwiseIf"},
|
||||
};
|
||||
|
||||
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <core/providers/common.h>
|
||||
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class TernaryOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
||||
Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& /* logger */) const {
|
||||
const auto& op_type(node.OpType());
|
||||
ORT_RETURN_IF(node.InputDefs().size() < 3, "Operator requires at least three inputs");
|
||||
|
||||
emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name());
|
||||
emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name());
|
||||
emscripten::val input2 = model_builder.GetOperand(node.InputDefs()[2]->Name());
|
||||
emscripten::val output = emscripten::val::object();
|
||||
if (op_type == "Where") {
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("elementwiseIf", input0, input1, input2);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"TernaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
}
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend())
|
||||
return;
|
||||
|
||||
static std::vector<std::string> op_types =
|
||||
{
|
||||
"Where",
|
||||
};
|
||||
|
||||
op_registrations.builders.push_back(std::make_unique<TernaryOpBuilder>());
|
||||
for (const auto& type : op_types) {
|
||||
op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -37,6 +37,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateBinaryOpBuilder("Pow", op_registrations);
|
||||
}
|
||||
|
||||
{ // Ternary
|
||||
CreateTernaryOpBuilder("Where", op_registrations);
|
||||
}
|
||||
|
||||
{ // Activations
|
||||
CreateActivationOpBuilder("Relu", op_registrations);
|
||||
CreateActivationOpBuilder("LeakyRelu", op_registrations);
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
|
|||
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue