Reland "[WebNN] Fallback the node when its output doesn't have shape info" (#22685)

The previous PR was reverted because it causes the whole model to
fallback when there is output shape info missing. This PR fixes the
issue by removing redundant fallbacks.
This commit is contained in:
shiyi 2024-11-12 08:30:10 +08:00 committed by GitHub
parent b1e0930eab
commit f7d1f0fc5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 25 additions and 21 deletions

View file

@ -69,17 +69,16 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}
bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
const auto& input_name = input.Name();
const auto* shape_proto = input.Shape();
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
if (input_name.empty()) {
if (node_arg_name.empty()) {
return true;
}
// We do not support input with no shape.
// We do not support input/output with no shape.
if (!shape_proto) {
LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
<< "] has not shape";
LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape";
return false;
}
@ -87,12 +86,11 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
// WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
if (!dim.has_dim_value()) {
LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
<< input_name;
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
return false;
}
if (dim.dim_value() == 0) {
LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has 0 dimension which is not supported by WebNN";
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
return false;
}
}
@ -106,13 +104,6 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const emscripten::val& wnn_limits,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;
for (const auto* input : graph_viewer.GetInputs()) {
if (!IsInputSupported(*input, "graph", logger)) {
return supported_node_groups;
}
}
std::vector<size_t> supported_node_group;
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();

View file

@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}
bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,

View file

@ -32,7 +32,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
if (!HasSupportedInputs(node, wnn_limits, logger))
return false;
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
if (!HasSupportedOutputs(node, wnn_limits, logger))
return false;
if (!HasSupportedOpSet(node, logger))
@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsInputSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger)) {
return false;
}
}
@ -66,6 +66,18 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
}
bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* output : node.OutputDefs()) {
if (!IsTensorShapeSupported(*output, node_name, logger)) {
return false;
}
}
return HasSupportedOutputsImpl(node, wnn_limits, logger);
}
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {

View file

@ -54,6 +54,7 @@ class BaseOpBuilder : public IOpBuilder {
private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
};
} // namespace webnn

View file

@ -227,7 +227,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
if (!shape.empty()) {
dims.reserve(shape.size());
for (const auto& dim : shape) {
// dim_param free dimensions should have already been excluded by IsInputSupported().
// dim_param free dimensions should have already been excluded by IsTensorShapeSupported().
assert(dim.has_dim_value());
dims.push_back(SafeInt<int32_t>(dim.dim_value()));
}