mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[WebNN EP] WebNN only supports 4-D input and weight for Conv/ConvTranspose (#18703)
This commit is contained in:
parent
b129f425fc
commit
1db1c75048
1 changed files with 30 additions and 13 deletions
|
|
@ -293,22 +293,39 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
const auto& op_type = node.OpType();
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
const auto& weight_name = input_defs[1]->Name();
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger)) {
|
||||
LOGS(logger, VERBOSE) << "Cannot get input's shape.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_size = input_shape.size();
|
||||
if (input_size != 4) {
|
||||
LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size
|
||||
<< ". Only conv 2d is supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64_t> weight_shape;
|
||||
if (!GetShape(*input_defs[1], weight_shape, logger)) {
|
||||
LOGS(logger, VERBOSE) << "Cannot get weight's shape.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto weight_size = weight_shape.size();
|
||||
if (weight_size != 4) {
|
||||
LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size
|
||||
<< ". Only conv 2d is supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// WebNN CPU backend (XNNPACK) requires the filter operand to be a constant.
|
||||
// https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739
|
||||
if (device_type == WebnnDeviceType::CPU) {
|
||||
if (Contains(initializers, weight_name)) {
|
||||
const auto& tensor = *initializers.at(weight_name);
|
||||
if (tensor.dims().size() != 4) {
|
||||
LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size()
|
||||
<< " Only conv 2d is supported.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known";
|
||||
return false;
|
||||
}
|
||||
if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) {
|
||||
LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue