mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Check the count of DequantizeLinear for matmul (#7174)
This commit is contained in:
parent
a01334ba56
commit
4f30341253
2 changed files with 34 additions and 1 deletions
|
|
@ -13,7 +13,7 @@ class QDQMatMulTransformer : public QDQOperatorTransformer {
|
|||
QDQMatMulTransformer(Node& node, Graph& graph) : QDQOperatorTransformer(node, graph) {}
|
||||
|
||||
bool Transform(const std::vector<const Node*>& parents, const std::vector<const Node*>& children) override {
|
||||
if (children.size() != 1) {
|
||||
if (parents.size() != 2 || children.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -231,6 +231,39 @@ TEST(QDQTransformerTests, MatMul) {
|
|||
test_case({22, 11, 13, 15}, {15, 13});
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, MatMul_No_Fusion) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& input2_shape) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input1_arg = builder.MakeInput<float>(input1_shape, -1.f, 1.f);
|
||||
auto* input2_arg = builder.MakeInput<float>(input2_shape, -1.f, 1.f);
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add QDQ + MatMul
|
||||
auto* matmul_output = builder.MakeIntermediate();
|
||||
auto* dq_matmul_output1 = AddQDQNodePair<uint8_t>(builder, input1_arg, .004f, 129);
|
||||
builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output});
|
||||
|
||||
// add Q
|
||||
builder.AddQuantizeLinearNode<uint8_t>(matmul_output, .0039f, 135, output_arg);
|
||||
};
|
||||
|
||||
auto check_matmul_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["MatMul"], 1);
|
||||
EXPECT_EQ(op_to_count["QLinearMatMul"], 0);
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2);
|
||||
};
|
||||
|
||||
// Test the basic case of a single 1D/2D/3D convolution.
|
||||
test_case({12, 37}, {37, 12});
|
||||
test_case({23, 13, 13}, {13, 13});
|
||||
test_case({22, 11, 13, 15}, {15, 13});
|
||||
}
|
||||
|
||||
#endif // DISABLE_CONTRIB_OPS
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue