mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Fix an issue of CumSum operator (#22936)
This PR limits the axis of the CumSum operator to be a constant when using WebNN EP. @Honry @fdwr PTAL.
This commit is contained in:
parent
f80afeb9a1
commit
558ae8621c
3 changed files with 29 additions and 9 deletions
|
|
@ -25,7 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|
|||
| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) |
|
||||
| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group |
|
||||
| Cos | ai.onnx(7+) | cos | ✓ | ✓ | |
|
||||
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | |
|
||||
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | 'axis' input should be a constant |
|
||||
| Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | |
|
||||
| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | |
|
||||
| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode |
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ namespace webnn {
|
|||
class CumSumOpBuilder : public BaseOpBuilder {
|
||||
// Add operator related.
|
||||
|
||||
public:
|
||||
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
|
||||
|
||||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
|
@ -30,8 +33,13 @@ class CumSumOpBuilder : public BaseOpBuilder {
|
|||
};
|
||||
|
||||
// Add operator related.
|
||||
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
const Node& node,
|
||||
|
||||
void CumSumOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
// Skip axis.
|
||||
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
|
||||
}
|
||||
|
||||
Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
|
||||
|
|
@ -39,10 +47,14 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
|
||||
const auto input_rank = input_shape.size();
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
int64_t axis = helper.Get("axis", 0);
|
||||
axis = HandleNegativeAxis(axis, input_rank);
|
||||
const auto& initializers = model_builder.GetInitializerTensors();
|
||||
const std::string axis_name = GetTensorName(input_defs, 1);
|
||||
const auto axis_tensor = *initializers.at(axis_name);
|
||||
emscripten::val axis = emscripten::val::undefined();
|
||||
ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value");
|
||||
int64_t webnn_axis = HandleNegativeAxis(axis.as<int64_t>(), input_rank);
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
const auto exclusive = helper.Get("exclusive", 0);
|
||||
const auto reverse = helper.Get("reverse", 0);
|
||||
|
||||
|
|
@ -52,13 +64,14 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
options.set("label", node.Name());
|
||||
|
||||
emscripten::val output = emscripten::val::object();
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(axis), options);
|
||||
output = model_builder.GetBuilder().call<emscripten::val>("cumulativeSum", input, gsl::narrow<uint32_t>(webnn_axis),
|
||||
options);
|
||||
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Operator support related.
|
||||
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
|
||||
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
||||
const Node& node,
|
||||
WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const {
|
||||
|
|
@ -68,6 +81,13 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
if (!GetShape(*input_defs[0], input_shape, logger))
|
||||
return false;
|
||||
|
||||
const std::string axis_name = GetTensorName(input_defs, 1);
|
||||
// Inputs contain optional 'axis' input.
|
||||
if (!Contains(initializers, axis_name)) {
|
||||
LOGS(logger, VERBOSE) << "The axis must be a constant initializer.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
}
|
||||
|
||||
{ // CumSum
|
||||
CreateConcatOpBuilder("CumSum", op_registrations);
|
||||
CreateCumSumOpBuilder("CumSum", op_registrations);
|
||||
}
|
||||
|
||||
{ // Dropout
|
||||
|
|
|
|||
Loading…
Reference in a new issue