mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Do not fuse DQ+Node+Q if DQ produces graph output (#14509)
Fix issue #14501
This commit is contained in:
parent
3d388a1aea
commit
d9e675a2af
2 changed files with 64 additions and 0 deletions
|
|
@ -51,6 +51,14 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
|
|||
return false;
|
||||
}
|
||||
|
||||
auto does_node_produce_graph_output = [&graph_viewer](const Node* node_ptr) {
|
||||
return graph_viewer.NodeProducesGraphOutput(*node_ptr);
|
||||
};
|
||||
|
||||
if (std::any_of(dq_nodes.begin(), dq_nodes.end(), does_node_produce_graph_output)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (q_nodes.empty()) {
|
||||
return is_empty_q_nodes_allowed;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2760,5 +2760,61 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_GraphInputToOutput) {
|
|||
test_case(false);
|
||||
}
|
||||
|
||||
// Do not fuse DQ + node + Q when the DQ produces a graph output
|
||||
TEST(QDQTransformerTests, DQ_Produce_Graph_Output) {
  // Builds Q -> DQ -> Softmax -> Q -> DQ, where the output of the first DQ is created with
  // MakeOutput() and also feeds Softmax, then checks that the transformer left every node
  // unfused.
  auto test_case = [&](const std::vector<int64_t>& input_shape, int64_t axis) {
    auto build_test_case = [&](ModelTestBuilder& builder) {
      auto* graph_input = builder.MakeInput<float>(input_shape, -5.f, 5.f);
      // Both outputs are requested up front, in this order: first DQ output, then final output.
      auto* first_dq_out = builder.MakeOutput();
      auto* final_out = builder.MakeOutput();

      // Q/DQ pair on the input side.
      const float input_scale = .105f;
      const uint8_t input_zero_point = 127;
      auto* first_q_out = builder.MakeIntermediate();
      builder.AddQuantizeLinearNode<uint8_t>(graph_input, input_scale, input_zero_point, first_q_out);
      builder.AddDequantizeLinearNode<uint8_t>(first_q_out, input_scale, input_zero_point, first_dq_out);

      // Softmax consumes the first DQ output.
      auto* softmax_out = builder.MakeIntermediate();
      auto& softmax = builder.AddNode("Softmax", {first_dq_out}, {softmax_out});
      softmax.AddAttribute("axis", axis);

      // Q/DQ pair on the output side.
      const float output_scale = 1.0f / (std::numeric_limits<uint8_t>::max() + 1);
      const uint8_t output_zero_point = 0;
      auto* second_q_out = builder.MakeIntermediate();
      builder.AddQuantizeLinearNode<uint8_t>(softmax_out, output_scale, output_zero_point, second_q_out);
      builder.AddDequantizeLinearNode<uint8_t>(second_q_out, output_scale, output_zero_point, final_out);
    };

    // Expect the original five nodes and no QLinearSoftmax.
    auto check_graph = [&](InferenceSessionWrapper& session) {
      auto op_counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_counts["com.microsoft.QLinearSoftmax"], 0);
      EXPECT_EQ(op_counts["Softmax"], 1);
      EXPECT_EQ(op_counts["QuantizeLinear"], 2);
      EXPECT_EQ(op_counts["DequantizeLinear"], 2);
    };

    TransformerTester(build_test_case,
                      check_graph,
                      TransformerLevel::Level1,
                      TransformerLevel::Level2,
                      12 /*opset_version*/,
                      0.01 /*per_sample_tolerance*/,
                      0.01 /*relative_per_sample_tolerance*/,
                      std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
  };

  test_case({1, 12, 37}, -1);
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue