diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 2ff127e9a0..d23fe11355 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -25,7 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | -| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | | +| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | 'axis' input should be a constant | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | | DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | | | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index 018060f18c..be30c5520d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -19,6 +19,9 @@ namespace webnn { class CumSumOpBuilder : public BaseOpBuilder { // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; @@ -30,8 +33,13 @@ class CumSumOpBuilder : public BaseOpBuilder { }; // Add operator related. -Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, + +void CumSumOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip axis. + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); +} + +Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); @@ -39,10 +47,14 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); const auto input_rank = input_shape.size(); - NodeAttrHelper helper(node); - int64_t axis = helper.Get("axis", 0); - axis = HandleNegativeAxis(axis, input_rank); + const auto& initializers = model_builder.GetInitializerTensors(); + const std::string axis_name = GetTensorName(input_defs, 1); + const auto axis_tensor = *initializers.at(axis_name); + emscripten::val axis = emscripten::val::undefined(); + ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value"); + int64_t webnn_axis = HandleNegativeAxis(axis.as(), input_rank); + NodeAttrHelper helper(node); const auto exclusive = helper.Get("exclusive", 0); const auto reverse = helper.Get("reverse", 0); @@ -52,13 +64,14 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); - output = model_builder.GetBuilder().call("cumulativeSum", input, gsl::narrow(axis), options); + output = model_builder.GetBuilder().call("cumulativeSum", input, gsl::narrow(webnn_axis), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } // Operator support related. -bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, +bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const { @@ -68,6 +81,13 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ if (!GetShape(*input_defs[0], input_shape, logger)) return false; + const std::string axis_name = GetTensorName(input_defs, 1); + // Inputs contain optional 'axis' input. + if (!Contains(initializers, axis_name)) { + LOGS(logger, VERBOSE) << "The axis must be a constant initializer."; + return false; + } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 4b26ca65f1..6d1c572128 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -82,7 +82,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { } { // CumSum - CreateConcatOpBuilder("CumSum", op_registrations); + CreateCumSumOpBuilder("CumSum", op_registrations); } { // Dropout