mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[WebNN EP] Support Tile operator (#22148)
PTAL, thanks! @Honry , @fdwr thanks!
This commit is contained in:
parent
98a75900ef
commit
004bd36f3d
6 changed files with 107 additions and 2 deletions
|
|
@ -92,6 +92,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|
|||
| Sub | ai.onnx(7-12, 13, 14+) | sub | ✓ | ✓ | |
|
||||
| Tan | ai.onnx(7+) | tan | ✓ | ✓ | |
|
||||
| Tanh | ai.onnx(7-12, 13+) | tanh | ✓ | ✓ | |
|
||||
| Tile | ai.onnx(7-12, 13+) | tile | ✗ | ✓ | Input 'repeats' should be a constant |
|
||||
| Transpose | ai.onnx(7-12, 13-20, 21+) | transpose | ✓ | ✓ | |
|
||||
| Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant |
|
||||
| Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | |
|
||||
|
|
|
|||
|
|
@ -2498,8 +2498,8 @@
|
|||
// "test_thresholdedrelu_default",
|
||||
// "test_thresholdedrelu_example",
|
||||
// "test_thresholdedrelu",
|
||||
// "test_tile_precomputed",
|
||||
// "test_tile",
|
||||
"test_tile_precomputed",
|
||||
"test_tile",
|
||||
// // "test_top_k_negative_axis",
|
||||
// // "test_top_k_smallest",
|
||||
// // "test_top_k",
|
||||
|
|
|
|||
|
|
@ -237,6 +237,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
|
|||
{"Sub", "sub"},
|
||||
{"Tan", "tan"},
|
||||
{"Tanh", "tanh"},
|
||||
{"Tile", "tile"},
|
||||
{"Transpose", "transpose"},
|
||||
{"Trilu", "triangular"},
|
||||
{"Unsqueeze", "reshape"},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class TileOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
public:
|
||||
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
|
||||
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
||||
void TileOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
|
||||
}
|
||||
|
||||
Status TileOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& initializers(model_builder.GetInitializerTensors());
|
||||
const auto& repetitions_initializer = *initializers.at(input_defs[1]->Name());
|
||||
const int64_t* raw_repetitions_data = repetitions_initializer.int64_data().empty()
|
||||
? reinterpret_cast<const int64_t*>(repetitions_initializer.raw_data().data())
|
||||
: repetitions_initializer.int64_data().data();
|
||||
const auto size = repetitions_initializer.dims()[0];
|
||||
TensorShapeVector repetitions_data{raw_repetitions_data, raw_repetitions_data + size};
|
||||
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
|
||||
std::vector<uint32_t> repetitions;
|
||||
std::transform(repetitions_data.cbegin(), repetitions_data.cend(),
|
||||
std::back_inserter(repetitions),
|
||||
[](int64_t repetition) -> uint32_t { return SafeInt<uint32_t>(repetition); });
|
||||
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("label", node.Name());
|
||||
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("tile",
|
||||
input,
|
||||
emscripten::val::array(repetitions),
|
||||
options);
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
|
||||
bool TileOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
||||
const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& repetitions_name = input_defs[1]->Name();
|
||||
if (!Contains(initializers, repetitions_name)) {
|
||||
LOGS(logger, VERBOSE) << "Repetitions of tile must be a constant initializer";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
|
||||
if (input_shape.empty()) {
|
||||
LOGS(logger, VERBOSE) << "Tile does not support empty input shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<TileOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -191,6 +191,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateSqueezeUnsqueezeOpBuilder("Unsqueeze", op_registrations);
|
||||
}
|
||||
|
||||
{ // Tile
|
||||
CreateTileOpBuilder("Tile", op_registrations);
|
||||
}
|
||||
|
||||
{ // Transpose
|
||||
CreateTransposeOpBuilder("Transpose", op_registrations);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations&
|
|||
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
|
|||
Loading…
Reference in a new issue