From cfe68e489e2f9a7932023e27550d1f564efff58a Mon Sep 17 00:00:00 2001 From: Peishen Yan Date: Sat, 25 May 2024 01:46:54 +0800 Subject: [PATCH] [WebNN EP] Support Trilu op (#20730) Adds support for Trilu via WebNN Triangular op --- js/web/docs/webnn-operators.md | 1 + .../core/providers/webnn/builders/helper.h | 1 + .../builders/impl/triangular_op_builder.cc | 103 ++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 3 + .../webnn/builders/op_builder_factory.h | 1 + 5 files changed, 109 insertions(+) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 98e0a11c57..bcabb6896f 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -88,5 +88,6 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Tan | ai.onnx(7+) | tan | ✗ | ✓ | | | Tanh | ai.onnx(7-12, 13+) | tanh | ✓ | ✓ | | | Transpose | ai.onnx(7-12, 13-20, 21+) | transpose | ✓ | ✓ | | +| Trilu | ai.onnx(14+) | triangular | ✗ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant | | Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Where | ai.onnx(7-8, 9-15, 16+) | where | ✗ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 58fb344c60..9cec78bbfe 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -236,6 +236,7 @@ static const InlinedHashMap op_map = { {"Tan", {"tan", false}}, {"Tanh", {"tanh", true}}, {"Transpose", {"transpose", true}}, + {"Trilu", {"triangular", false}}, {"Unsqueeze", {"reshape", true}}, {"Where", {"where", false}}, }; diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc new file mode 100644 index 0000000000..caf3979d61 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class TriangularOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void TriangularOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip diagonal initializer if present. + if (node.InputDefs().size() > 1) { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + } +} + +Status TriangularOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& initializers = model_builder.GetInitializerTensors(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = emscripten::val::object(); + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); + + const bool upper = helper.Get("upper", 1); + options.set("upper", upper); + + if (!GetTensorName(input_defs, 1).empty()) { + // Optional input diagonal is provided, use diagonal initializer data. + const auto diagonal_tensor = *initializers.at(input_defs[1]->Name()); + + std::vector unpacked_tensor; + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(diagonal_tensor, unpacked_tensor)); + const auto diagonal = *reinterpret_cast(unpacked_tensor.data()); + options.set("diagonal", narrow(diagonal)); + } + + output = model_builder.GetBuilder().call("triangular", input, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. +bool TriangularOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto input_size = input_shape.size(); + if (input_size < 2) { + LOGS(logger, VERBOSE) << "Triangular only support input size >= 2d shape, input is " + << input_size << "d shape"; + return false; + } + + const std::string diagonal_name = GetTensorName(input_defs, 1); + emscripten::val diagonal = emscripten::val::object(); + // Inputs contain optional 'diagonal' input. + if (!diagonal_name.empty()) { + if (!Contains(initializers, diagonal_name)) { + LOGS(logger, VERBOSE) << "The diagonal must be a constant initializer."; + return false; + } + } + return true; +} + +void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index c39a5510cf..dfe015725c 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -182,6 +182,9 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateTransposeOpBuilder("Transpose", op_registrations); } + { // Trilu + CreateTriangularOpBuilder("Trilu", op_registrations); + } return op_registrations; } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index a50a7318e3..818ff094fb 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -47,6 +47,7 @@ void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace webnn