Do not fuse DQ+Node+Q if DQ produces graph output (#14509)

Fix issue #14501
This commit is contained in:
Yufeng Li 2023-02-01 13:36:47 -08:00 committed by GitHub
parent 3d388a1aea
commit d9e675a2af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 0 deletions

View file

@ -51,6 +51,14 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
return false;
}
// Predicate over a DQ node pointer: true when the graph viewer reports that
// the node produces a graph output. The pointer is dereferenced unconditionally.
auto does_node_produce_graph_output = [&graph_viewer](const Node* node_ptr) {
return graph_viewer.NodeProducesGraphOutput(*node_ptr);
};
// Reject the whole node group as soon as any DQ node feeds a graph output.
// NOTE(review): presumably fusing would drop the DQ output that the graph
// still has to expose -- confirm against the fusion actions.
if (std::any_of(dq_nodes.begin(), dq_nodes.end(), does_node_produce_graph_output)) {
return false;
}
if (q_nodes.empty()) {
return is_empty_q_nodes_allowed;
}

View file

@ -2760,5 +2760,61 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_GraphInputToOutput) {
test_case(false);
}
// A DQ -> Softmax -> Q group must stay unfused when the DQ output is also a graph output.
TEST(QDQTransformerTests, DQ_Produce_Graph_Output) {
  auto run_case = [&](const std::vector<int64_t>& input_shape, int64_t axis) {
    // Graph under test:
    //   input -> Q -> DQ (also a graph output) -> Softmax -> Q -> DQ -> graph output
    auto build_graph = [&](ModelTestBuilder& builder) {
      const float input_scale = .105f;
      const float output_scale = 1.0f / (std::numeric_limits<uint8_t>::max() + 1);

      auto* graph_input = builder.MakeInput<float>(input_shape, -5.f, 5.f);
      // The first DQ writes straight into a graph output; this is the case being tested.
      auto* dq_graph_output = builder.MakeOutput();
      auto* final_graph_output = builder.MakeOutput();

      // Q/DQ pair on the input side.
      auto* quantized_input = builder.MakeIntermediate();
      builder.AddQuantizeLinearNode<uint8_t>(graph_input, input_scale, 127, quantized_input);
      builder.AddDequantizeLinearNode<uint8_t>(quantized_input, input_scale, 127, dq_graph_output);

      // Softmax consumes the DQ output that is also exposed by the graph.
      auto* softmax_result = builder.MakeIntermediate();
      auto& softmax = builder.AddNode("Softmax", {dq_graph_output}, {softmax_result});
      softmax.AddAttribute("axis", axis);

      // Q/DQ pair on the output side.
      auto* quantized_result = builder.MakeIntermediate();
      builder.AddQuantizeLinearNode<uint8_t>(softmax_result, output_scale, 0, quantized_result);
      builder.AddDequantizeLinearNode<uint8_t>(quantized_result, output_scale, 0, final_graph_output);
    };

    // After the transformer runs, every original node must still be present
    // and no QLinearSoftmax may have been created.
    auto verify_graph = [](InferenceSessionWrapper& session) {
      auto op_counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_counts["com.microsoft.QLinearSoftmax"], 0);
      EXPECT_EQ(op_counts["Softmax"], 1);
      EXPECT_EQ(op_counts["QuantizeLinear"], 2);
      EXPECT_EQ(op_counts["DequantizeLinear"], 2);
    };

    TransformerTester(build_graph,
                      verify_graph,
                      TransformerLevel::Level1,
                      TransformerLevel::Level2,
                      12 /*opset_version*/,
                      0.01 /*per_sample_tolerance*/,
                      0.01 /*relative_per_sample_tolerance*/,
                      std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
  };

  run_case({1, 12, 37}, -1);
}
} // namespace test
} // namespace onnxruntime