mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[WebNN EP] Support Trilu op (#20730)
Adds support for Trilu via WebNN Triangular op
This commit is contained in:
parent
33a68d221f
commit
cfe68e489e
5 changed files with 109 additions and 0 deletions
|
|
@ -88,5 +88,6 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|
|||
| Tan | ai.onnx(7+) | tan | ✗ | ✓ | |
|
||||
| Tanh | ai.onnx(7-12, 13+) | tanh | ✓ | ✓ | |
|
||||
| Transpose | ai.onnx(7-12, 13-20, 21+) | transpose | ✓ | ✓ | |
|
||||
| Trilu | ai.onnx(14+) | triangular | ✗ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant |
|
||||
| Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | |
|
||||
| Where | ai.onnx(7-8, 9-15, 16+) | where | ✗ | ✓ | |
|
||||
|
|
|
|||
|
|
@ -236,6 +236,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
|
|||
{"Tan", {"tan", false}},
|
||||
{"Tanh", {"tanh", true}},
|
||||
{"Transpose", {"transpose", true}},
|
||||
{"Trilu", {"triangular", false}},
|
||||
{"Unsqueeze", {"reshape", true}},
|
||||
{"Where", {"where", false}},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class TriangularOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
public:
|
||||
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
|
||||
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
||||
void TriangularOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
// Skip diagonal initializer if present.
|
||||
if (node.InputDefs().size() > 1) {
|
||||
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
|
||||
}
|
||||
}
|
||||
|
||||
Status TriangularOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& initializers = model_builder.GetInitializerTensors();
|
||||
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
|
||||
emscripten::val output = emscripten::val::object();
|
||||
NodeAttrHelper helper(node);
|
||||
emscripten::val options = emscripten::val::object();
|
||||
|
||||
const bool upper = helper.Get("upper", 1);
|
||||
options.set("upper", upper);
|
||||
|
||||
if (!GetTensorName(input_defs, 1).empty()) {
|
||||
// Optional input diagonal is provided, use diagonal initializer data.
|
||||
const auto diagonal_tensor = *initializers.at(input_defs[1]->Name());
|
||||
|
||||
std::vector<uint8_t> unpacked_tensor;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(diagonal_tensor, unpacked_tensor));
|
||||
const auto diagonal = *reinterpret_cast<int64_t*>(unpacked_tensor.data());
|
||||
options.set("diagonal", narrow<int32_t>(diagonal));
|
||||
}
|
||||
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("triangular", input, options);
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
bool TriangularOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
const auto input_size = input_shape.size();
|
||||
if (input_size < 2) {
|
||||
LOGS(logger, VERBOSE) << "Triangular only support input size >= 2d shape, input is "
|
||||
<< input_size << "d shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string diagonal_name = GetTensorName(input_defs, 1);
|
||||
emscripten::val diagonal = emscripten::val::object();
|
||||
// Inputs contain optional 'diagonal' input.
|
||||
if (!diagonal_name.empty()) {
|
||||
if (!Contains(initializers, diagonal_name)) {
|
||||
LOGS(logger, VERBOSE) << "The diagonal must be a constant initializer.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<TriangularOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -182,6 +182,9 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateTransposeOpBuilder("Transpose", op_registrations);
|
||||
}
|
||||
|
||||
{ // Trilu
|
||||
CreateTriangularOpBuilder("Trilu", op_registrations);
|
||||
}
|
||||
return op_registrations;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
|
|||
void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
} // namespace webnn
|
||||
|
|
|
|||
Loading…
Reference in a new issue