[WebNN EP] WebNN only supports 4-D input and weight for Conv/ConvTranspose (#18703)

This commit is contained in:
Wanming Lin 2023-12-15 06:33:19 +08:00 committed by GitHub
parent b129f425fc
commit 1db1c75048
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -293,22 +293,39 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
const auto& weight_name = input_defs[1]->Name();
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, VERBOSE) << "Cannot get input's shape.";
return false;
}
const auto input_size = input_shape.size();
if (input_size != 4) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size
<< ". Only conv 2d is supported.";
return false;
}
std::vector<int64_t> weight_shape;
if (!GetShape(*input_defs[1], weight_shape, logger)) {
LOGS(logger, VERBOSE) << "Cannot get weight's shape.";
return false;
}
const auto weight_size = weight_shape.size();
if (weight_size != 4) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size
<< ". Only conv 2d is supported.";
return false;
}
// WebNN CPU backend (XNNPACK) requires the filter operand to be a constant.
// https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739
if (device_type == WebnnDeviceType::CPU) {
if (Contains(initializers, weight_name)) {
const auto& tensor = *initializers.at(weight_name);
if (tensor.dims().size() != 4) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size()
<< " Only conv 2d is supported.";
return false;
}
} else {
LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known";
return false;
}
if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) {
LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known";
return false;
}
return true;
}