mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[WebNN EP] Fixed bug in ConvTranspose (#21569)
The constraint of ConvTranspose was placed in wrong place.
This commit is contained in:
parent
c5f8389648
commit
a3883af7bf
1 changed files with 17 additions and 17 deletions
|
|
@ -28,7 +28,7 @@ class ConvOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
|
@ -378,6 +378,22 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return false;
|
||||
}
|
||||
|
||||
// WebNN CPU backend (TFLite) only supports default dilations and group.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040
|
||||
if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") {
|
||||
NodeAttrHelper helper(node);
|
||||
const auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
|
||||
const auto group = helper.Get("group", 1);
|
||||
if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) {
|
||||
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1.";
|
||||
return false;
|
||||
}
|
||||
if (group != 1) {
|
||||
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -427,22 +443,6 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy
|
|||
return false;
|
||||
}
|
||||
|
||||
// WebNN CPU backend (TFLite) only supports default dilations and group.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040
|
||||
if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") {
|
||||
NodeAttrHelper helper(node);
|
||||
const auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
|
||||
const auto group = helper.Get("group", 1);
|
||||
if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) {
|
||||
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1.";
|
||||
return false;
|
||||
}
|
||||
if (group != 1) {
|
||||
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue