Check the count of DequantizeLinear for matmul (#7174)

This commit is contained in:
Yufeng Li 2021-03-30 09:00:48 -07:00 committed by GitHub
parent a01334ba56
commit 4f30341253
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 1 deletions

View file

@ -13,7 +13,7 @@ class QDQMatMulTransformer : public QDQOperatorTransformer {
QDQMatMulTransformer(Node& node, Graph& graph) : QDQOperatorTransformer(node, graph) {}
bool Transform(const std::vector<const Node*>& parents, const std::vector<const Node*>& children) override {
if (children.size() != 1) {
if (parents.size() != 2 || children.size() != 1) {
return false;
}

View file

@ -231,6 +231,39 @@ TEST(QDQTransformerTests, MatMul) {
test_case({22, 11, 13, 15}, {15, 13});
}
TEST(QDQTransformerTests, MatMul_No_Fusion) {
auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& input2_shape) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<float>(input1_shape, -1.f, 1.f);
auto* input2_arg = builder.MakeInput<float>(input2_shape, -1.f, 1.f);
auto* output_arg = builder.MakeOutput();
// add QDQ + MatMul
auto* matmul_output = builder.MakeIntermediate();
auto* dq_matmul_output1 = AddQDQNodePair<uint8_t>(builder, input1_arg, .004f, 129);
builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output});
// add Q
builder.AddQuantizeLinearNode<uint8_t>(matmul_output, .0039f, 135, output_arg);
};
auto check_matmul_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["MatMul"], 1);
EXPECT_EQ(op_to_count["QLinearMatMul"], 0);
EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
};
TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2);
};
// Test the basic case of a single 1D/2D/3D convolution.
test_case({12, 37}, {37, 12});
test_case({23, 13, 13}, {13, 13});
test_case({22, 11, 13, 15}, {15, 13});
}
#endif // DISABLE_CONTRIB_OPS
} // namespace test