From dcf91266bd2840dfd1df4a994916a4524f4e868d Mon Sep 17 00:00:00 2001 From: shiyi Date: Tue, 29 Oct 2024 06:04:45 +0800 Subject: [PATCH] [WebNN EP] Support GatherND and ScatterND op (#22181) --- js/web/docs/webnn-operators.md | 2 + js/web/test/suite-test-list.jsonc | 12 +-- .../core/providers/webnn/builders/helper.h | 2 + .../builders/impl/gatherND_op_builder.cc | 80 +++++++++++++++++ .../builders/impl/scatterND_op_builder.cc | 89 +++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 8 ++ .../webnn/builders/op_builder_factory.h | 2 + 7 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index bf0f1dffb8..ecbfe2e9c4 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | | | Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | | +| GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 | | Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | | | Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input | | GlobalAveragePool | ai.onnx(7+) | averagePool2d | ✓ | ✓ | Only supports 4-D input | @@ -79,6 +80,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | | Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | +| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index dcfc8ccc39..d81d741aa0 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1777,9 +1777,9 @@ "test_gather_elements_0", "test_gather_elements_1", "test_gather_elements_negative_indices", - // "test_gathernd_example_float32", - // "test_gathernd_example_int32_batch_dim1", - // "test_gathernd_example_int32", + "test_gathernd_example_float32", + "test_gathernd_example_int32_batch_dim1", + "test_gathernd_example_int32", "test_gemm_all_attributes", "test_gemm_alpha", "test_gemm_beta", @@ -2260,9 +2260,9 @@ // // "test_scatter_elements_without_axis", // // "test_scatter_with_axis", // // "test_scatter_without_axis", - // // "test_scatternd_add", - // // "test_scatternd_multiply", - // // "test_scatternd", + "test_scatternd_add", + "test_scatternd_multiply", + "test_scatternd", // // "test_sce_mean_3d_expanded", // // "test_sce_mean_3d_log_prob_expanded", // // "test_sce_mean_3d_log_prob", diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index aa3613551d..5ea16253b6 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -215,6 +215,7 @@ static const InlinedHashMap op_map = { {"Flatten", "reshape"}, {"Floor", "floor"}, {"Gather", "gather"}, + {"GatherND", "gatherND"}, {"Gelu", "gelu"}, {"Gemm", "gemm"}, {"GlobalAveragePool", "averagePool2d"}, @@ -260,6 +261,7 @@ static const InlinedHashMap op_map = { {"Relu", "relu"}, {"Reshape", "reshape"}, {"Resize", "resample2d"}, + {"ScatterND", "scatterND"}, {"Shape", "slice"}, {"Sigmoid", "sigmoid"}, {"Softplus", "softplus"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc new file mode 100644 index 0000000000..cb4f85a40e --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GatherNDOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("gatherND", data, indices, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("batch_dims", 0) != 0) { + LOGS(logger, VERBOSE) << "GatherND: WebNN only supports batch_dims 0 (default)"; + return false; + } + + return true; +} + +bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc new file mode 100644 index 0000000000..feb93cc14b --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ScatterNDOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = + model_builder.GetBuilder().call("scatterND", data, indices, updates, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("reduction", "none") != "none") { + LOGS(logger, VERBOSE) << "ScatterND: WebNN only supports reduction type none (default)"; + return false; + } + + return true; +} + +bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& updates = *node.InputDefs()[2]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + int32_t updates_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) || + !GetType(updates, updates_type, logger)) { + return false; + } + + if (data_type != updates_type) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 8baa479024..93fed1704e 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGatherOpBuilder("Gather", op_registrations); } + { // GatherND + CreateGatherNDOpBuilder("GatherND", op_registrations); + } + { // Flatten CreateFlattenOpBuilder("Flatten", op_registrations); } @@ -170,6 +174,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateResizeOpBuilder("Resize", op_registrations); } + { // ScatterND + CreateScatterNDOpBuilder("ScatterND", op_registrations); + } + { // Shape CreateShapeOpBuilder("Shape", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 990be04d42..2278571b5a 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -43,6 +44,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);