mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Reland "[WebNN] Fallback the node when its output doesn't have shape info" (#22685)
The previous PR was reverted because it causes the whole model to fallback when there is output shape info missing. This PR fixes the issue by removing redundant fallbacks.
This commit is contained in:
parent
b1e0930eab
commit
f7d1f0fc5e
5 changed files with 25 additions and 21 deletions
|
|
@ -69,17 +69,16 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
|
|||
}
|
||||
}
|
||||
|
||||
bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
|
||||
const auto& input_name = input.Name();
|
||||
const auto* shape_proto = input.Shape();
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
|
||||
const auto& node_arg_name = node_arg.Name();
|
||||
const auto* shape_proto = node_arg.Shape();
|
||||
// Optional tensors can be indicated by an empty name, just ignore it.
|
||||
if (input_name.empty()) {
|
||||
if (node_arg_name.empty()) {
|
||||
return true;
|
||||
}
|
||||
// We do not support input with no shape.
|
||||
// We do not support input/output with no shape.
|
||||
if (!shape_proto) {
|
||||
LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
|
||||
<< "] has not shape";
|
||||
LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -87,12 +86,11 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
|
|||
// WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
|
||||
if (!dim.has_dim_value()) {
|
||||
LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
|
||||
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
|
||||
<< input_name;
|
||||
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
|
||||
return false;
|
||||
}
|
||||
if (dim.dim_value() == 0) {
|
||||
LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has 0 dimension which is not supported by WebNN";
|
||||
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -106,13 +104,6 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) {
|
||||
std::vector<std::vector<size_t>> supported_node_groups;
|
||||
|
||||
for (const auto* input : graph_viewer.GetInputs()) {
|
||||
if (!IsInputSupported(*input, "graph", logger)) {
|
||||
return supported_node_groups;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> supported_node_group;
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
|
|||
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
|
||||
}
|
||||
|
||||
bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
|
||||
|
||||
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
|
|||
if (!HasSupportedInputs(node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
|
||||
if (!HasSupportedOutputs(node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOpSet(node, logger))
|
||||
|
|
@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
|
|||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
if (!IsInputSupported(*input, node_name, logger)) {
|
||||
if (!IsTensorShapeSupported(*input, node_name, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -66,6 +66,18 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
|
|||
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* output : node.OutputDefs()) {
|
||||
if (!IsTensorShapeSupported(*output, node_name, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return HasSupportedOutputsImpl(node, wnn_limits, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
private:
|
||||
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
|
||||
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
};
|
||||
|
||||
} // namespace webnn
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
|
|||
if (!shape.empty()) {
|
||||
dims.reserve(shape.size());
|
||||
for (const auto& dim : shape) {
|
||||
// dim_param free dimensions should have already been excluded by IsInputSupported().
|
||||
// dim_param free dimensions should have already been excluded by IsTensorShapeSupported().
|
||||
assert(dim.has_dim_value());
|
||||
dims.push_back(SafeInt<int32_t>(dim.dim_value()));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue