mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Revert "[WebNN] Fallback the node when its output doesn't have shape info" (#22669)
Reverts microsoft/onnxruntime#22556 since it causes incorrect fallback.
This commit is contained in:
parent
f9bc24e1a7
commit
c7ecc081ca
5 changed files with 15 additions and 31 deletions
|
|
@ -69,16 +69,17 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
|
|||
}
|
||||
}
|
||||
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
|
||||
const auto& node_arg_name = node_arg.Name();
|
||||
const auto* shape_proto = node_arg.Shape();
|
||||
bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
|
||||
const auto& input_name = input.Name();
|
||||
const auto* shape_proto = input.Shape();
|
||||
// Optional tensors can be indicated by an empty name, just ignore it.
|
||||
if (node_arg_name.empty()) {
|
||||
if (input_name.empty()) {
|
||||
return true;
|
||||
}
|
||||
// We do not support input/output with no shape.
|
||||
// We do not support input with no shape.
|
||||
if (!shape_proto) {
|
||||
LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape";
|
||||
LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
|
||||
<< "] has not shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -86,11 +87,12 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
|
|||
// WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
|
||||
if (!dim.has_dim_value()) {
|
||||
LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
|
||||
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
|
||||
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
|
||||
<< input_name;
|
||||
return false;
|
||||
}
|
||||
if (dim.dim_value() == 0) {
|
||||
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
|
||||
LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has 0 dimension which is not supported by WebNN";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -106,12 +108,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
std::vector<std::vector<size_t>> supported_node_groups;
|
||||
|
||||
for (const auto* input : graph_viewer.GetInputs()) {
|
||||
if (!IsTensorShapeSupported(*input, "graph", logger)) {
|
||||
return supported_node_groups;
|
||||
}
|
||||
}
|
||||
for (const auto* output : graph_viewer.GetOutputs()) {
|
||||
if (!IsTensorShapeSupported(*output, "graph", logger)) {
|
||||
if (!IsInputSupported(*input, "graph", logger)) {
|
||||
return supported_node_groups;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
|
|||
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
|
||||
}
|
||||
|
||||
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
|
||||
bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
|
||||
|
||||
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
|
||||
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
|
|||
if (!HasSupportedInputs(node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOutputs(node, wnn_limits, logger))
|
||||
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOpSet(node, logger))
|
||||
|
|
@ -47,7 +47,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
|
|||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
if (!IsTensorShapeSupported(*input, node_name, logger)) {
|
||||
if (!IsInputSupported(*input, node_name, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -68,18 +68,6 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
|
|||
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* output : node.OutputDefs()) {
|
||||
if (!IsTensorShapeSupported(*output, node_name, logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return HasSupportedOutputsImpl(node, wnn_limits, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
private:
|
||||
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
|
||||
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
};
|
||||
|
||||
} // namespace webnn
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
|
|||
if (!shape.empty()) {
|
||||
dims.reserve(shape.size());
|
||||
for (const auto& dim : shape) {
|
||||
// dim_param free dimensions should have already been excluded by IsTensorShapeSupported().
|
||||
// dim_param free dimensions should have already been excluded by IsInputSupported().
|
||||
assert(dim.has_dim_value());
|
||||
dims.push_back(SafeInt<int32_t>(dim.dim_value()));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue