[WebNN EP] Support GatherND and ScatterND op (#22181)

This commit is contained in:
shiyi 2024-10-29 06:04:45 +08:00 committed by GitHub
parent 975d3dffcf
commit dcf91266bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 189 additions and 6 deletions

View file

@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | |
| Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | |
| Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | |
| GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 |
| Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | |
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input |
| GlobalAveragePool | ai.onnx(7+) | averagePool2d | ✓ | ✓ | Only supports 4-D input |
@ -79,6 +80,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | |
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' |
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | |
| Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | |

View file

@ -1777,9 +1777,9 @@
"test_gather_elements_0",
"test_gather_elements_1",
"test_gather_elements_negative_indices",
// "test_gathernd_example_float32",
// "test_gathernd_example_int32_batch_dim1",
// "test_gathernd_example_int32",
"test_gathernd_example_float32",
"test_gathernd_example_int32_batch_dim1",
"test_gathernd_example_int32",
"test_gemm_all_attributes",
"test_gemm_alpha",
"test_gemm_beta",
@ -2260,9 +2260,9 @@
// // "test_scatter_elements_without_axis",
// // "test_scatter_with_axis",
// // "test_scatter_without_axis",
// // "test_scatternd_add",
// // "test_scatternd_multiply",
// // "test_scatternd",
"test_scatternd_add",
"test_scatternd_multiply",
"test_scatternd",
// // "test_sce_mean_3d_expanded",
// // "test_sce_mean_3d_log_prob_expanded",
// // "test_sce_mean_3d_log_prob",

View file

@ -215,6 +215,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Flatten", "reshape"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherND", "gatherND"},
{"Gelu", "gelu"},
{"Gemm", "gemm"},
{"GlobalAveragePool", "averagePool2d"},
@ -260,6 +261,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Softplus", "softplus"},

View file

@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"
#include "base_op_builder.h"
namespace onnxruntime {
namespace webnn {
// Maps an ONNX GatherND node onto the WebNN "gatherND" operation.
class GatherNDOpBuilder : public BaseOpBuilder {
  // Add operator related.
 private:
  // Translates the node into a WebNN gatherND op and registers its output operand.
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
  // Operator support related.
  // Rejects nodes whose attributes WebNN cannot express (any non-zero batch_dims).
  bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
  // Checks the data/indices element types against the backend's reported limits.
  bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                              const logging::Logger& logger) const override;
};
// Add operator related.
// Add operator related.

// Builds the WebNN gatherND op: output = builder.gatherND(data, indices, options).
Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                                const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();

  // Resolve the WebNN operands for the two GatherND inputs.
  emscripten::val wnn_data = model_builder.GetOperand(input_defs[0]->Name());
  emscripten::val wnn_indices = model_builder.GetOperand(input_defs[1]->Name());

  // Label the op with the node name so backend errors can be traced to the graph.
  emscripten::val wnn_options = emscripten::val::object();
  wnn_options.set("label", node.Name());

  emscripten::val result =
      model_builder.GetBuilder().call<emscripten::val>("gatherND", wnn_data, wnn_indices, wnn_options);
  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(result));
  return Status::OK();
}
// Operator support related.
// Operator support related.

// WebNN's gatherND has no counterpart for ONNX's batch_dims attribute, so only
// the default value of 0 is supported.
bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                                          const WebnnDeviceType /* device_type */,
                                          const logging::Logger& logger) const {
  NodeAttrHelper helper(node);
  const int batch_dims = helper.Get("batch_dims", 0);
  if (batch_dims == 0) {
    return true;
  }
  LOGS(logger, VERBOSE) << "GatherND: WebNN only supports batch_dims 0 (default)";
  return false;
}
// Verifies both input element types against the backend's opSupportLimits for gatherND.
bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                                               const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  const auto& op_type = node.OpType();

  // Determine the ONNX element types; give up if either is unknown.
  int32_t data_type;
  if (!GetType(*input_defs[0], data_type, logger)) {
    return false;
  }
  int32_t indices_type;
  if (!GetType(*input_defs[1], indices_type, logger)) {
    return false;
  }

  if (!IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger)) {
    return false;
  }
  return IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}
// Registers a GatherNDOpBuilder for the given ONNX op type.
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto builder = std::make_unique<GatherNDOpBuilder>();
  op_registrations.op_builder_map.emplace(op_type, builder.get());
  op_registrations.builders.push_back(std::move(builder));
}
} // namespace webnn
} // namespace onnxruntime

View file

@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"
#include "base_op_builder.h"
namespace onnxruntime {
namespace webnn {
// Maps an ONNX ScatterND node onto the WebNN "scatterND" operation.
class ScatterNDOpBuilder : public BaseOpBuilder {
  // Add operator related.
 private:
  // Translates the node into a WebNN scatterND op and registers its output operand.
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
  // Operator support related.
  // Rejects nodes whose attributes WebNN cannot express (any reduction other than 'none').
  bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
  // Checks the data/indices/updates element types against the backend's reported limits.
  bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                              const logging::Logger& logger) const override;
};
// Add operator related.
// Add operator related.

// Builds the WebNN scatterND op: output = builder.scatterND(data, indices, updates, options).
Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                                 const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();

  // Resolve the WebNN operands for the three ScatterND inputs.
  emscripten::val wnn_data = model_builder.GetOperand(input_defs[0]->Name());
  emscripten::val wnn_indices = model_builder.GetOperand(input_defs[1]->Name());
  emscripten::val wnn_updates = model_builder.GetOperand(input_defs[2]->Name());

  // Label the op with the node name so backend errors can be traced to the graph.
  emscripten::val wnn_options = emscripten::val::object();
  wnn_options.set("label", node.Name());

  emscripten::val result = model_builder.GetBuilder().call<emscripten::val>("scatterND", wnn_data, wnn_indices,
                                                                            wnn_updates, wnn_options);
  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(result));
  return Status::OK();
}
// Operator support related.
// Operator support related.

// WebNN's scatterND always overwrites at the target indices; ONNX reduction
// modes (add, mul, ...) have no WebNN counterpart, so only 'none' is supported.
bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                                           const WebnnDeviceType /* device_type */,
                                           const logging::Logger& logger) const {
  NodeAttrHelper helper(node);
  const auto reduction = helper.Get("reduction", "none");
  if (reduction == "none") {
    return true;
  }
  LOGS(logger, VERBOSE) << "ScatterND: WebNN only supports reduction type none (default)";
  return false;
}
// Verifies the data/indices/updates element types against the backend's
// opSupportLimits for scatterND. Returns false (node falls back to CPU EP)
// when a type cannot be determined, when data and updates disagree, or when
// a type is outside the backend's limits.
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                                                const logging::Logger& logger) const {
  const auto& data = *node.InputDefs()[0];
  const auto& indices = *node.InputDefs()[1];
  const auto& updates = *node.InputDefs()[2];
  const auto& op_type = node.OpType();
  int32_t data_type;
  int32_t indices_type;
  int32_t updates_type;
  if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) ||
      !GetType(updates, updates_type, logger)) {
    return false;
  }
  // WebNN scatterND requires updates to have the same element type as data.
  // Log the rejection so the fallback is diagnosable, consistent with the
  // other rejection paths in this builder.
  if (data_type != updates_type) {
    LOGS(logger, VERBOSE) << "ScatterND: data type and updates type must be the same";
    return false;
  }
  // updates shares data_type here, so checking data_type covers both.
  return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
         IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}
// Registers a ScatterNDOpBuilder for the given ONNX op type.
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto builder = std::make_unique<ScatterNDOpBuilder>();
  op_registrations.op_builder_map.emplace(op_type, builder.get());
  op_registrations.builders.push_back(std::move(builder));
}
} // namespace webnn
} // namespace onnxruntime

View file

@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateGatherOpBuilder("Gather", op_registrations);
}
{ // GatherND
CreateGatherNDOpBuilder("GatherND", op_registrations);
}
{ // Flatten
CreateFlattenOpBuilder("Flatten", op_registrations);
}
@ -170,6 +174,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateResizeOpBuilder("Resize", op_registrations);
}
{ // ScatterND
CreateScatterNDOpBuilder("ScatterND", op_registrations);
}
{ // Shape
CreateShapeOpBuilder("Shape", op_registrations);
}

View file

@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
@ -43,6 +44,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);