[WebNN EP] Fix an issue of CumSum operator (#22936)

This PR limits the axis of the CumSum operator to be a constant when
using WebNN EP.
@Honry  @fdwr PTAL.
This commit is contained in:
Bin Miao 2024-11-26 13:05:53 +08:00 committed by GitHub
parent f80afeb9a1
commit 558ae8621c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 29 additions and 9 deletions

View file

@ -25,7 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) |
| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group |
| Cos | ai.onnx(7+) | cos | ✓ | ✓ | |
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | |
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | 'axis' input should be a constant |
| Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | |
| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | |
| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode |

View file

@ -19,6 +19,9 @@ namespace webnn {
class CumSumOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
@ -30,8 +33,13 @@ class CumSumOpBuilder : public BaseOpBuilder {
};
// Add operator related.
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
void CumSumOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
// Skip axis.
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
}
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
@ -39,10 +47,14 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
const auto input_rank = input_shape.size();
NodeAttrHelper helper(node);
int64_t axis = helper.Get("axis", 0);
axis = HandleNegativeAxis(axis, input_rank);
const auto& initializers = model_builder.GetInitializerTensors();
const std::string axis_name = GetTensorName(input_defs, 1);
const auto axis_tensor = *initializers.at(axis_name);
emscripten::val axis = emscripten::val::undefined();
ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value");
int64_t webnn_axis = HandleNegativeAxis(axis.as<int64_t>(), input_rank);
NodeAttrHelper helper(node);
const auto exclusive = helper.Get("exclusive", 0);
const auto reverse = helper.Get("reverse", 0);
@ -52,13 +64,14 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
options.set("label", node.Name());
emscripten::val output = emscripten::val::object();
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(axis), options);
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(webnn_axis),
options);
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}
// Operator support related.
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
@ -68,6 +81,13 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
if (!GetShape(*input_defs[0], input_shape, logger))
return false;
const std::string axis_name = GetTensorName(input_defs, 1);
// Inputs contain optional 'axis' input.
if (!Contains(initializers, axis_name)) {
LOGS(logger, VERBOSE) << "The axis must be a constant initializer.";
return false;
}
return true;
}

View file

@ -82,7 +82,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
}
{ // CumSum
CreateConcatOpBuilder("CumSum", op_registrations);
CreateCumSumOpBuilder("CumSum", op_registrations);
}
{ // Dropout