From 1f3b675453e8412e5c084bfb95997967d0c2eec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 11 Nov 2024 19:48:25 +0100 Subject: [PATCH] Fix MatMulBnFusion to exclude cases when tensors are not 2D tensors (#22762) ### Description Fixes #22512. MatMul, Add can be fused into a single Gemm even if tensor dimensions are > 2. This PR excludes those cases. ### Motivation and Context ORT crashes on valid models due to that unexpected fusion. --- .../core/optimizer/matmul_bn_fusion.cc | 17 ++++++++++ .../test/optimizer/graph_transform_test.cc | 29 ++++++++++++++++++ .../fuse-matmul-bn-directly-dont-fuse.onnx | Bin 0 -> 517 bytes 3 files changed, 46 insertions(+) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index e944522c9c..6b76dc626f 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons return false; } + // Checks the first input of MatMul has 2 dimensions. + // The test for the second input is done in method Apply as it accesses the constant. + if (node.InputDefs()[0] == nullptr) { + // This should never happen but just in case. + return false; + } + auto shape_a = node.InputDefs()[0]->Shape(); + if (shape_a == nullptr) { + // We cannot shape the rank. It is better to avoid fusing. + return false; + } + if (shape_a->dim_size() != 2) { + // Gemm only supports 2D tensors. + return false; + } + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. 
const auto& output_defs = batch_norm_node->OutputDefs(); if (output_defs.size() > 1) { @@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& bias_tensor->dims_size() != 1 || mean_tensor->dims_size() != 1 || var_tensor->dims_size() != 1 || + matmul_b_tensor->dims_size() != 2 || scale_tensor->dims(0) != matmul_b_tensor->dims(1) || bias_tensor->dims(0) != matmul_b_tensor->dims(1) || mean_tensor->dims(0) != matmul_b_tensor->dims(1) || diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index ee3a1baade..67d60ea3a4 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) { } } +TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + 
ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8ca8282572db8d474327e5f88d7e8a1f60c47e15 GIT binary patch literal 517 zcmd zSDarY#0wS3FD(JeE3x?|miU(DaQSngN^o%`<;52#C+4Jbu>)C2nTf? zq%5&Whz)9pkW*qwa)w`iQEp;RW>sPd&@n<{FpKkG&I7wutCoY6gH?c0DW&9@y}fs& zzrATnsojapNc+btLhLz8((JXXI_#F_bK0x1?6p(Ye{Q$n%?jIp7hc)TKdWx{F3r&H z+?8o_>{