diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index b37340624f..e94db2faa8 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -293,22 +293,39 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - const auto& weight_name = input_defs[1]->Name(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get input's shape."; + return false; + } + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size + << ". Only conv 2d is supported."; + return false; + } + + std::vector weight_shape; + if (!GetShape(*input_defs[1], weight_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get weight's shape."; + return false; + } + + const auto weight_size = weight_shape.size(); + if (weight_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size + << ". Only conv 2d is supported."; + return false; + } + // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. // https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 - if (device_type == WebnnDeviceType::CPU) { - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d is supported."; - return false; - } - } else { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; - } + if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; + return false; } + return true; }