Fix NNAPI EP error when handling external node adjacent to partition. (#11233)

Move a check for a graph output (for the partition) prior to iterating the downstream nodes to avoid trying to get a NodeUnit for a node that is outside of the partition.
This commit is contained in:
Edward Chen 2022-04-20 08:53:29 -07:00 committed by GitHub
parent 70d97bdf53
commit e3ff4a6bfa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 26 deletions

View file

@@ -146,9 +146,9 @@ void ModelBuilder::PreprocessActivations() {
}
const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
// In theory, if node_unit_map_ is generated correctly, see PreprocessNodeUnits(), a NodeUnit can be
// found for any single node in the graph_viewer_, unless the given node is not from graph_viewer_
return *node_unit_map_.at(node);
const auto node_unit_it = node_unit_map_.find(node);
ORT_ENFORCE(node_unit_it != node_unit_map_.end(), "Node does not have corresponding NodeUnit.");
return *node_unit_it->second;
}
void ModelBuilder::PreprocessNodeUnits() {
@@ -620,13 +620,12 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
}
int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
const auto& output_nodes = node_unit.GetOutputNodes();
if (node_unit.GetOutputNodes().size() != 1) {
LOGS_DEFAULT(VERBOSE) << "FindActivation does not support, NodeUnit [" << node_unit.Name()
<< "] type [" << node_unit.OpType()
<< "], with " << output_nodes.size() << " output nodes";
return fuse_code;
return ANEURALNETWORKS_FUSED_NONE;
}
const auto& outputs = node_unit.Outputs();
@@ -634,42 +633,53 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
LOGS_DEFAULT(VERBOSE) << "FindActivation does not support, NodeUnit [" << node_unit.Name()
<< "] type [" << node_unit.OpType()
<< "], with " << outputs.size() << " outputs";
return fuse_code;
return ANEURALNETWORKS_FUSED_NONE;
}
const NodeArg& output = outputs[0].node_arg;
const auto& output_node = *output_nodes[0];
// if output is a graph output, will add activation separately
if (const auto& graph_outputs = graph_viewer_.GetOutputs();
std::find(graph_outputs.cbegin(), graph_outputs.cend(), &output) != graph_outputs.cend()) {
return ANEURALNETWORKS_FUSED_NONE;
}
// TODO, add support of activation fusion for quantized node group (qdq or qlinear)
// We do not support activation fusion for quantized operators for now
// (usually the activations are fused already in the quantization)
auto quant_op_type = GetQuantizedOpType(node_unit);
if (quant_op_type != QuantizedOpType::Unknown)
return fuse_code;
if (auto quant_op_type = GetQuantizedOpType(node_unit);
quant_op_type != QuantizedOpType::Unknown) {
return ANEURALNETWORKS_FUSED_NONE;
}
const auto& output_node = *output_nodes[0];
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
bool fuse_code_assigned_from_activation = false;
for (auto it = output_node.OutputEdgesBegin(), end = output_node.OutputEdgesEnd(); it != end; ++it) {
const auto& dst_node = it->GetNode();
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
const auto& dst_node_unit = GetNodeUnit(&dst_node);
if (Contains(activation_node_units_, &dst_node_unit)) {
if (&output == dst_input) {
fuse_code = activation_node_units_.at(&dst_node_unit);
}
} else {
// if there is any other non-relu node using the output
// will add relu separately
if (&output == dst_input)
return ANEURALNETWORKS_FUSED_NONE;
}
}
// if output is a graph output, will add activation separately
if (fuse_code != ANEURALNETWORKS_FUSED_NONE) {
const auto& graph_outputs = graph_viewer_.GetOutputs();
if (std::find(graph_outputs.cbegin(), graph_outputs.cend(), &output) != graph_outputs.cend()) {
if (&output != dst_input) {
continue;
}
const auto& dst_node_unit = GetNodeUnit(&dst_node);
auto activation_it = activation_node_units_.find(&dst_node_unit);
if (activation_it == activation_node_units_.end()) {
// output node is not a fusable activation
return ANEURALNETWORKS_FUSED_NONE;
}
if (fuse_code_assigned_from_activation) {
// don't overwrite a previously assigned fuse code, just don't fuse
return ANEURALNETWORKS_FUSED_NONE;
}
fuse_code = activation_it->second;
fuse_code_assigned_from_activation = true;
}
if (fuse_code != ANEURALNETWORKS_FUSED_NONE) {
LOGS_DEFAULT(VERBOSE) << "Node [" << node_unit.Name() << "] type [" << node_unit.OpType()
<< "], fused the output [" << output.Name() << "]";

View file

@@ -109,6 +109,7 @@ class ModelBuilder {
const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
// Get the NodeUnit which contains the given node
// the given node must be in the underlying graph_viewer
const NodeUnit& GetNodeUnit(const Node* node) const;
private:

View file

@@ -492,6 +492,22 @@ TEST(NnapiExecutionProviderTest, TestOrtFormatModel) {
#endif
}
// Regression test for the fix in this commit: looking up a NodeUnit for a node
// adjacent to, but outside of, the NNAPI partition must not fail.
// test that NNAPI EP can process an activation node that is outside of its partition
TEST(NnapiExecutionProviderTest, ActivationOutsideOfPartition) {
// model starts with Conv -> Relu
constexpr auto* model_file_name = ORT_TSTR("testdata/mnist.level1_opt.ort");
// stop NNAPI partitioning at Relu so NNAPI EP only takes first Conv
// (the Relu then becomes an external node adjacent to the partition)
const auto nnapi_partitioning_stop_ops = "Relu";
SessionOptions so;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(
std::make_unique<NnapiExecutionProvider>(0, nnapi_partitioning_stop_ops)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
// Initialize() performs partitioning/compilation; per the commit description this
// previously errored when handling the out-of-partition activation node.
ASSERT_STATUS_OK(session_object.Initialize());
// expect one NNAPI partition
ASSERT_EQ(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 1);
}
} // namespace test
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@@ -7,3 +7,8 @@ ai.onnx;8;MaxPool{"inputs": {"0": ["float"]}},Sum{"inputs": {"0": ["float"]}}
ai.onnx;9;Cast{"inputs": {"0": ["float"]}, "outputs": {"0": ["bool"]}}
ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},If,Loop
ai.onnx.ml;1;ArrayFeatureExtractor,LinearClassifier,Normalizer,ZipMap
# Note: The lines below were added manually.
# TODO find a way to avoid manual modification of this file
# also include Transpose added by layout transformation
ai.onnx;1;Transpose