mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Support GatherND and ScatterND op (#22181)
This commit is contained in:
parent
975d3dffcf
commit
dcf91266bd
7 changed files with 189 additions and 6 deletions
|
|
@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|
|||
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | |
|
||||
| Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | |
|
||||
| Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | |
|
||||
| GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 |
|
||||
| Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | |
|
||||
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input |
|
||||
| GlobalAveragePool | ai.onnx(7+) | averagePool2d | ✓ | ✓ | Only supports 4-D input |
|
||||
|
|
@ -79,6 +80,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|
|||
| Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | |
|
||||
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
|
||||
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
|
||||
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' |
|
||||
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | |
|
||||
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | |
|
||||
| Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | |
|
||||
|
|
|
|||
|
|
@ -1777,9 +1777,9 @@
|
|||
"test_gather_elements_0",
|
||||
"test_gather_elements_1",
|
||||
"test_gather_elements_negative_indices",
|
||||
// "test_gathernd_example_float32",
|
||||
// "test_gathernd_example_int32_batch_dim1",
|
||||
// "test_gathernd_example_int32",
|
||||
"test_gathernd_example_float32",
|
||||
"test_gathernd_example_int32_batch_dim1",
|
||||
"test_gathernd_example_int32",
|
||||
"test_gemm_all_attributes",
|
||||
"test_gemm_alpha",
|
||||
"test_gemm_beta",
|
||||
|
|
@ -2260,9 +2260,9 @@
|
|||
// // "test_scatter_elements_without_axis",
|
||||
// // "test_scatter_with_axis",
|
||||
// // "test_scatter_without_axis",
|
||||
// // "test_scatternd_add",
|
||||
// // "test_scatternd_multiply",
|
||||
// // "test_scatternd",
|
||||
"test_scatternd_add",
|
||||
"test_scatternd_multiply",
|
||||
"test_scatternd",
|
||||
// // "test_sce_mean_3d_expanded",
|
||||
// // "test_sce_mean_3d_log_prob_expanded",
|
||||
// // "test_sce_mean_3d_log_prob",
|
||||
|
|
|
|||
|
|
@ -215,6 +215,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
|
|||
{"Flatten", "reshape"},
|
||||
{"Floor", "floor"},
|
||||
{"Gather", "gather"},
|
||||
{"GatherND", "gatherND"},
|
||||
{"Gelu", "gelu"},
|
||||
{"Gemm", "gemm"},
|
||||
{"GlobalAveragePool", "averagePool2d"},
|
||||
|
|
@ -260,6 +261,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
|
|||
{"Relu", "relu"},
|
||||
{"Reshape", "reshape"},
|
||||
{"Resize", "resample2d"},
|
||||
{"ScatterND", "scatterND"},
|
||||
{"Shape", "slice"},
|
||||
{"Sigmoid", "sigmoid"},
|
||||
{"Softplus", "softplus"},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class GatherNDOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
||||
Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
|
||||
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("label", node.Name());
|
||||
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gatherND", data, indices, options);
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
|
||||
bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
NodeAttrHelper helper(node);
|
||||
if (helper.Get("batch_dims", 0) != 0) {
|
||||
LOGS(logger, VERBOSE) << "GatherND: WebNN only supports batch_dims 0 (default)";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& data = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
const auto& op_type = node.OpType();
|
||||
|
||||
int32_t data_type;
|
||||
int32_t indices_type;
|
||||
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
|
||||
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
|
||||
}
|
||||
|
||||
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<GatherNDOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Intel Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/webnn/builders/helper.h"
|
||||
#include "core/providers/webnn/builders/model_builder.h"
|
||||
#include "core/providers/webnn/builders/op_builder_factory.h"
|
||||
|
||||
#include "base_op_builder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace webnn {
|
||||
|
||||
class ScatterNDOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
||||
Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
|
||||
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
|
||||
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
|
||||
emscripten::val options = emscripten::val::object();
|
||||
options.set("label", node.Name());
|
||||
emscripten::val output =
|
||||
model_builder.GetBuilder().call<emscripten::val>("scatterND", data, indices, updates, options);
|
||||
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
|
||||
bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
NodeAttrHelper helper(node);
|
||||
if (helper.Get("reduction", "none") != "none") {
|
||||
LOGS(logger, VERBOSE) << "ScatterND: WebNN only supports reduction type none (default)";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& data = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
const auto& updates = *node.InputDefs()[2];
|
||||
const auto& op_type = node.OpType();
|
||||
|
||||
int32_t data_type;
|
||||
int32_t indices_type;
|
||||
int32_t updates_type;
|
||||
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) ||
|
||||
!GetType(updates, updates_type, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (data_type != updates_type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
|
||||
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
|
||||
}
|
||||
|
||||
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<ScatterNDOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace webnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateGatherOpBuilder("Gather", op_registrations);
|
||||
}
|
||||
|
||||
{ // GatherND
|
||||
CreateGatherNDOpBuilder("GatherND", op_registrations);
|
||||
}
|
||||
|
||||
{ // Flatten
|
||||
CreateFlattenOpBuilder("Flatten", op_registrations);
|
||||
}
|
||||
|
|
@ -170,6 +174,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateResizeOpBuilder("Resize", op_registrations);
|
||||
}
|
||||
|
||||
{ // ScatterND
|
||||
CreateScatterNDOpBuilder("ScatterND", op_registrations);
|
||||
}
|
||||
|
||||
{ // Shape
|
||||
CreateShapeOpBuilder("Shape", op_registrations);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR
|
|||
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
@ -43,6 +44,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
|
|||
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
|
|
|||
Loading…
Reference in a new issue